diff --git a/.venv/Lib/site-packages/torch/_inductor/__pycache__/codecache.cpython-39.pyc b/.venv/Lib/site-packages/torch/_inductor/__pycache__/codecache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71d046c83f02969d2dcb10fd175ac5de23f4c643 Binary files /dev/null and b/.venv/Lib/site-packages/torch/_inductor/__pycache__/codecache.cpython-39.pyc differ diff --git a/.venv/Lib/site-packages/torch/_inductor/__pycache__/config.cpython-39.pyc b/.venv/Lib/site-packages/torch/_inductor/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89511bf0e72f72b79be38c3f4d555ef0c649a357 Binary files /dev/null and b/.venv/Lib/site-packages/torch/_inductor/__pycache__/config.cpython-39.pyc differ diff --git a/.venv/Lib/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-39.pyc b/.venv/Lib/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4bc0556420a0b9fe9ae28d0d0b2cb9d164d3c21 Binary files /dev/null and b/.venv/Lib/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-39.pyc differ diff --git a/.venv/Lib/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-39.pyc b/.venv/Lib/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e66b077b80b394addb4bc8fec8d99cd20f53ac2d Binary files /dev/null and b/.venv/Lib/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-39.pyc differ diff --git a/.venv/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-39.pyc b/.venv/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..165cb48ab9e96ef717cd472b8614a0404e8e09ef Binary files /dev/null and b/.venv/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-39.pyc differ diff --git a/.venv/Lib/site-packages/torch/_inductor/__pycache__/exc.cpython-39.pyc b/.venv/Lib/site-packages/torch/_inductor/__pycache__/exc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0aff0432293bb13a862d40b1bcf3935fbd98dc4 Binary files /dev/null and b/.venv/Lib/site-packages/torch/_inductor/__pycache__/exc.cpython-39.pyc differ diff --git a/.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py b/.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py new file mode 100644 index 0000000000000000000000000000000000000000..347e169d1ac234d8139f64b51fa5304f09464654 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py @@ -0,0 +1,296 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! 
+# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MMRankingA100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + + def get_name(self) -> str: + return 'mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]: + if context.get_value('arith_intensity') <= 52.6245059967041: + if context.get_value('n') <= 34.0: + if context.get_value('n') <= 18.0: + if context.get_value('k*n') <= 312.0: + return [(0.093, 12), (0.081, 16), (0.081, 148), (0.070, 10), (0.070, 17), (0.070, 149), (0.070, 151), (0.070, 150), (0.070, 14), (0.058, 11), (0.058, 15), (0.058, 13), (0.058, 122), (0.047, 121), (0.035, 123), (0.012, 92)] + else: + if context.get_value('k') <= 40.0: + return [(0.083, 42), (0.083, 46), (0.083, 44), (0.083, 40), (0.083, 128), (0.067, 45), (0.067, 43), (0.067, 41), (0.067, 169), (0.067, 171), (0.067, 168), (0.067, 129), (0.067, 170), (0.033, 103), (0.017, 121)] + else: + return [(0.112, 137), (0.104, 136), (0.101, 0), (0.081, 1), (0.073, 135), (0.069, 67), (0.066, 187), (0.058, 41), (0.050, 71), (0.046, 68), (0.046, 70), (0.031, 44), (0.027, 43), (0.027, 170), (0.019, 189), (0.019, 188), (0.015, 169), (0.015, 171), (0.012, 115), (0.012, 168), (0.012, 69), (0.004, 103)] + else: + if context.get_value('mat1_stride_0') <= 20.0: + return [(0.069, 0), (0.059, 157), (0.059, 22), (0.059, 153), (0.059, 155), (0.059, 25), (0.059, 23), (0.059, 19), (0.044, 21), (0.044, 18), (0.044, 152), (0.044, 158), (0.044, 154), (0.044, 156), (0.044, 20), (0.044, 124), (0.044, 24), (0.030, 125), (0.029, 126), (0.015, 97), (0.015, 95), (0.015, 96), (0.010, 2), (0.010, 75)] + else: + if context.get_value('k') <= 68.0: + return [(0.087, 72), (0.087, 74), (0.087, 73), (0.086, 76), (0.077, 75), (0.067, 192), (0.058, 190), (0.048, 47), (0.048, 193), (0.048, 49), (0.048, 
51), (0.048, 191), (0.038, 53), (0.019, 133), (0.019, 50), (0.019, 175), (0.019, 172), (0.019, 48), (0.019, 174), (0.010, 173), (0.010, 177), (0.010, 52), (0.010, 54), (0.010, 178), (0.010, 176)] + else: + return [(0.154, 52), (0.154, 72), (0.102, 75), (0.087, 49), (0.087, 73), (0.086, 51), (0.057, 176), (0.045, 2), (0.038, 191), (0.038, 178), (0.038, 190), (0.029, 173), (0.029, 76), (0.026, 138), (0.013, 139), (0.013, 140), (0.003, 0)] + else: + if context.get_value('k') <= 35.0: + if context.get_value('k') <= 18.0: + if context.get_value('m*n') <= 19505152.0: + return [(0.151, 159), (0.140, 160), (0.129, 164), (0.055, 127), (0.051, 29), (0.044, 161), (0.044, 147), (0.040, 146), (0.040, 31), (0.037, 145), (0.026, 28), (0.022, 90), (0.022, 93), (0.022, 94), (0.022, 100), (0.022, 125), (0.022, 158), (0.022, 157), (0.011, 87), (0.011, 88), (0.011, 89), (0.011, 91), (0.011, 95), (0.011, 96), (0.011, 98), (0.011, 99)] + else: + return [(0.069, 7), (0.069, 5), (0.067, 147), (0.066, 8), (0.061, 145), (0.058, 146), (0.052, 124), (0.049, 29), (0.049, 159), (0.046, 31), (0.043, 157), (0.041, 9), (0.041, 4), (0.040, 6), (0.035, 164), (0.035, 160), (0.026, 158), (0.017, 125), (0.017, 28), (0.017, 32), (0.017, 162), (0.017, 27), (0.017, 30), (0.017, 161), (0.009, 33), (0.009, 26), (0.009, 163), (0.006, 0)] + else: + if context.get_value('n') <= 68.0: + return [(0.101, 182), (0.101, 59), (0.088, 57), (0.076, 184), (0.076, 61), (0.076, 179), (0.076, 62), (0.076, 58), (0.063, 180), (0.063, 60), (0.051, 56), (0.050, 181), (0.025, 130), (0.025, 177), (0.025, 183), (0.013, 178), (0.013, 55)] + else: + return [(0.089, 180), (0.079, 60), (0.066, 35), (0.066, 181), (0.066, 38), (0.066, 58), (0.066, 179), (0.066, 57), (0.062, 184), (0.053, 37), (0.044, 166), (0.040, 55), (0.040, 39), (0.040, 36), (0.040, 165), (0.040, 167), (0.027, 177), (0.027, 34), (0.022, 159)] + else: + if context.get_value('m*n') <= 309760.0: + return [(0.298, 0), (0.097, 140), (0.080, 83), (0.072, 86), (0.044, 84), (0.036, 178), (0.036, 117), (0.036, 82), (0.032, 120), (0.032, 85), (0.028, 119), (0.024, 130), (0.024, 109), (0.020, 108), (0.020, 118), (0.012, 104), (0.012, 116), (0.012, 141), (0.012, 144), (0.008, 105), (0.008, 106), (0.008, 111), (0.008, 114), (0.008, 107), (0.008, 132), (0.004, 101), (0.004, 102), (0.004, 110), (0.004, 112), (0.004, 113), (0.004, 131)] + else: + if context.get_value('n') <= 72.0: + return [(0.227, 77), (0.118, 78), (0.102, 194), (0.086, 80), (0.059, 57), (0.054, 81), (0.049, 196), (0.048, 197), (0.048, 59), (0.043, 79), (0.032, 195), (0.027, 180), (0.022, 3), (0.021, 141), (0.016, 60), (0.016, 142), (0.011, 183), (0.011, 0), (0.011, 144)] + else: + return [(0.140, 186), (0.132, 185), (0.109, 63), (0.085, 65), (0.078, 37), (0.077, 35), (0.062, 197), (0.047, 194), (0.046, 165), (0.046, 57), (0.039, 78), (0.039, 79), (0.039, 66), (0.039, 64), (0.016, 195), (0.008, 159)] + else: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 815360.0: + if context.get_value('k') <= 1184.0: + return [(0.218, 140), (0.205, 0), (0.154, 144), (0.115, 141), (0.051, 185), (0.051, 104), (0.039, 78), (0.038, 116), (0.026, 165), (0.026, 130), (0.026, 178), (0.013, 57), (0.013, 195), (0.013, 167), (0.013, 186)] + else: + return [(0.901, 0), (0.030, 144), (0.030, 134), (0.016, 3), (0.006, 78), (0.006, 77), (0.002, 57), (0.002, 194), (0.002, 59), (0.002, 60), (0.002, 143)] + else: + if context.get_value('arith_intensity') <= 187.23922729492188: + if context.get_value('mat1_stride_0') <= 198.0: + 
return [(0.273, 63), (0.158, 37), (0.152, 35), (0.127, 57), (0.097, 165), (0.053, 185), (0.031, 0), (0.028, 64), (0.014, 60), (0.014, 78), (0.009, 55), (0.008, 134), (0.005, 34), (0.005, 167), (0.005, 179), (0.005, 65), (0.005, 66), (0.005, 186), (0.005, 194), (0.002, 166)] + else: + return [(0.296, 63), (0.235, 0), (0.132, 64), (0.074, 37), (0.069, 78), (0.051, 185), (0.051, 35), (0.030, 57), (0.020, 77), (0.016, 194), (0.008, 66), (0.007, 65), (0.003, 3), (0.003, 165), (0.003, 141), (0.001, 134), (0.001, 166)] + else: + return [(0.405, 0), (0.246, 37), (0.177, 63), (0.145, 35), (0.005, 185), (0.005, 65), (0.005, 64), (0.004, 57), (0.003, 66), (0.002, 165), (0.001, 78), (0.001, 55)] + else: + return [(0.357, 0), (0.112, 165), (0.101, 57), (0.094, 179), (0.086, 64), (0.074, 167), (0.067, 60), (0.064, 159), (0.033, 35), (0.007, 195), (0.002, 180), (0.001, 34), (0.001, 166), (0.001, 78)] diff --git a/.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py b/.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py new file mode 100644 index 0000000000000000000000000000000000000000..f1df85add5b84308c3d3cc121dd4713336ac9869 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py @@ -0,0 +1,321 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MMRankingH100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 232448 + and str(metadata.device_capa) == "(9, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=16_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4') + + def get_name(self) -> str: + return 'mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]: + if context.get_value('arith_intensity') <= 29.89772129058838: + if context.get_value('n') <= 34.0: + if context.get_value('n') <= 18.0: + if context.get_value('k*n') <= 432.0: + if context.get_value('arith_intensity') <= 7.8700292110443115: + return [(0.098, 128), (0.098, 129), (0.098, 127), (0.073, 14), (0.073, 16), (0.073, 12), (0.073, 154), (0.073, 156), (0.073, 157), (0.073, 155), (0.049, 10), (0.049, 94), (0.049, 95), (0.048, 96)] + else: + return [(0.091, 154), (0.073, 10), (0.073, 15), (0.073, 13), (0.073, 11), (0.073, 17), (0.073, 16), (0.073, 14), (0.073, 12), (0.055, 127), (0.054, 157), (0.054, 156), (0.054, 155), (0.036, 129), (0.036, 128), (0.018, 41), (0.018, 43)] + else: + if context.get_value('k') <= 40.0: + return [(0.070, 39), (0.069, 45), (0.069, 41), (0.069, 43), (0.069, 111), (0.069, 112), (0.056, 38), (0.056, 40), (0.056, 42), (0.056, 44), (0.056, 174), (0.056, 173), (0.056, 175), (0.056, 134), (0.056, 172), (0.056, 135), (0.014, 154), (0.014, 127)] + else: + return [(0.147, 144), (0.119, 143), (0.087, 142), (0.083, 0), (0.073, 191), (0.059, 69), (0.050, 67), (0.046, 70), (0.041, 1), (0.036, 174), (0.032, 43), (0.032, 123), (0.028, 40), (0.027, 42), (0.027, 173), (0.023, 175), (0.018, 66), (0.014, 192), (0.014, 193), (0.014, 139), (0.014, 68), (0.014, 127)] + else: + if context.get_value('mat1_stride_0') <= 40.0: + if context.get_value('mat1_stride_0') <= 20.0: + return [(0.109, 23), (0.109, 21), (0.109, 20), (0.088, 0), (0.087, 131), (0.066, 18), (0.065, 130), (0.065, 132), (0.065, 159), (0.065, 160), (0.065, 161), (0.065, 158), (0.022, 22), (0.022, 19)] + else: + return [(0.065, 46), (0.064, 52), (0.064, 50), (0.064, 48), (0.064, 51), (0.064, 49), (0.064, 47), (0.064, 53), (0.064, 181), (0.064, 177), (0.064, 179), (0.064, 176), (0.038, 130), (0.038, 136), (0.026, 182), (0.026, 178), (0.026, 180), (0.026, 137), (0.025, 158), (0.013, 114), (0.013, 113)] + else: + if context.get_value('mat1_stride_0') <= 68.0: + return [(0.138, 140), (0.125, 195), (0.100, 71), (0.100, 74), (0.100, 196), (0.100, 194), (0.100, 197), (0.075, 75), (0.062, 72), (0.062, 73), (0.012, 180), (0.012, 51), (0.012, 182)] + else: + return [(0.124, 180), (0.124, 182), (0.114, 75), (0.103, 74), (0.093, 51), (0.093, 71), (0.072, 72), (0.062, 194), (0.052, 145), (0.052, 195), (0.021, 48), (0.021, 50), (0.021, 47), (0.020, 124), (0.010, 147), (0.010, 146), (0.010, 46)] + else: + if context.get_value('k') <= 18.0: + if context.get_value('m*k') <= 528.0: + return [(0.097, 88), (0.087, 92), (0.077, 90), (0.058, 105), (0.058, 103), (0.058, 104), (0.058, 99), (0.058, 100), (0.058, 106), (0.058, 93), (0.057, 91), (0.057, 97), (0.057, 98), (0.057, 101), (0.048, 102), (0.029, 87), (0.029, 89)] + else: + if context.get_value('n') <= 80.0: + return [(0.057, 161), (0.057, 130), (0.057, 24), (0.056, 164), (0.056, 163), (0.056, 166), (0.056, 168), (0.056, 30), (0.056, 28), (0.056, 26), (0.056, 25), (0.056, 27), (0.056, 29), (0.056, 31), (0.042, 131), (0.028, 99), (0.028, 101), (0.028, 100), (0.028, 167), (0.028, 165), (0.028, 133)] + else: + return [(0.110, 164), (0.108, 163), (0.106, 168), (0.069, 161), (0.066, 151), (0.060, 152), (0.055, 165), (0.050, 27), (0.050, 29), (0.048, 131), (0.043, 153), (0.037, 133), (0.037, 130), (0.028, 8), (0.028, 5), (0.027, 7), (0.026, 26), (0.016, 162), 
(0.012, 9), (0.007, 4), (0.005, 100), (0.005, 6), (0.005, 24)] + else: + if context.get_value('k') <= 36.0: + if context.get_value('n') <= 68.0: + return [(0.097, 184), (0.097, 56), (0.086, 186), (0.086, 183), (0.086, 188), (0.086, 58), (0.086, 60), (0.065, 54), (0.043, 187), (0.043, 185), (0.043, 57), (0.043, 61), (0.032, 55), (0.032, 130), (0.032, 59), (0.011, 181), (0.011, 163), (0.011, 136), (0.011, 138)] + else: + return [(0.117, 184), (0.117, 170), (0.117, 169), (0.107, 183), (0.106, 188), (0.075, 181), (0.064, 130), (0.064, 56), (0.053, 171), (0.032, 57), (0.032, 59), (0.032, 185), (0.011, 163), (0.011, 32), (0.011, 37), (0.011, 34), (0.011, 33), (0.011, 35), (0.011, 36), (0.011, 54)] + else: + if context.get_value('mat2_stride_0') <= 384.0: + return [(0.244, 0), (0.061, 76), (0.061, 79), (0.030, 3), (0.030, 183), (0.030, 189), (0.030, 187), (0.030, 64), (0.030, 190), (0.030, 62), (0.030, 198), (0.030, 201), (0.030, 77), (0.030, 200), (0.030, 80), (0.030, 199), (0.030, 78), (0.030, 184), (0.020, 86), (0.020, 84), (0.020, 120), (0.020, 81), (0.020, 121), (0.020, 85), (0.020, 122), (0.010, 83), (0.010, 118), (0.010, 119), (0.010, 82)] + else: + return [(0.274, 83), (0.171, 86), (0.152, 0), (0.071, 85), (0.061, 125), (0.050, 84), (0.020, 109), (0.020, 117), (0.020, 81), (0.020, 118), (0.020, 121), (0.020, 108), (0.020, 115), (0.020, 116), (0.010, 110), (0.010, 120), (0.010, 103), (0.010, 107), (0.010, 119), (0.010, 122)] + else: + if context.get_value('arith_intensity') <= 56.995582580566406: + if context.get_value('n') <= 68.0: + if context.get_value('k*n') <= 4448.0: + if context.get_value('m*n') <= 29626368.0: + return [(0.107, 198), (0.107, 200), (0.107, 201), (0.107, 199), (0.106, 76), (0.106, 79), (0.064, 197), (0.063, 56), (0.043, 184), (0.043, 187), (0.042, 80), (0.042, 77), (0.042, 183), (0.021, 78)] + else: + return [(0.073, 201), (0.073, 198), (0.073, 200), (0.073, 199), (0.073, 197), (0.073, 56), (0.073, 58), (0.073, 79), (0.073, 76), (0.072, 59), (0.072, 78), (0.072, 77), (0.072, 80), (0.018, 184), (0.018, 55), (0.018, 54)] + else: + if context.get_value('k') <= 348.0: + return [(0.206, 76), (0.183, 77), (0.169, 198), (0.160, 199), (0.053, 59), (0.046, 56), (0.038, 3), (0.030, 148), (0.030, 58), (0.030, 187), (0.023, 184), (0.015, 0), (0.008, 55), (0.008, 54)] + else: + return [(0.146, 198), (0.145, 199), (0.145, 148), (0.126, 0), (0.084, 76), (0.084, 77), (0.042, 80), (0.042, 79), (0.021, 149), (0.021, 150), (0.021, 3), (0.014, 46), (0.014, 74), (0.014, 75), (0.014, 124), (0.014, 194), (0.014, 195), (0.007, 145), (0.007, 146), (0.007, 2), (0.007, 72), (0.007, 147), (0.007, 71)] + else: + if context.get_value('m') <= 3264.0: + return [(0.247, 147), (0.115, 197), (0.066, 199), (0.066, 201), (0.066, 198), (0.049, 0), (0.049, 169), (0.049, 171), (0.033, 140), (0.033, 125), (0.033, 114), (0.016, 126), (0.016, 183), (0.016, 184), (0.016, 185), (0.016, 182), (0.016, 188), (0.016, 78), (0.016, 148), (0.016, 138), (0.016, 77), (0.016, 56), (0.016, 59)] + else: + if context.get_value('k') <= 62.5: + return [(0.226, 190), (0.226, 189), (0.122, 62), (0.122, 64), (0.055, 77), (0.055, 78), (0.037, 198), (0.036, 201), (0.036, 33), (0.024, 163), (0.018, 56), (0.018, 35), (0.018, 169), (0.006, 171)] + else: + return [(0.162, 35), (0.118, 33), (0.096, 189), (0.096, 190), (0.088, 169), (0.074, 62), (0.073, 56), (0.066, 171), (0.051, 198), (0.051, 201), (0.044, 59), (0.037, 64), (0.029, 63), (0.007, 0), (0.007, 77)] + else: + if context.get_value('m*n') <= 1097728.0: + return [(0.403, 0), 
(0.179, 141), (0.134, 150), (0.086, 147), (0.051, 148), (0.048, 3), (0.024, 189), (0.020, 199), (0.017, 64), (0.010, 65), (0.010, 77), (0.007, 114), (0.003, 138), (0.003, 59), (0.003, 182)] + else: + if context.get_value('m*n') <= 3244032.0: + return [(0.295, 189), (0.176, 64), (0.157, 65), (0.090, 0), (0.069, 62), (0.059, 63), (0.046, 77), (0.039, 169), (0.023, 199), (0.020, 35), (0.013, 33), (0.010, 171), (0.003, 141)] + else: + if context.get_value('n') <= 136.0: + return [(0.197, 189), (0.197, 63), (0.161, 77), (0.157, 62), (0.061, 33), (0.044, 65), (0.039, 35), (0.039, 64), (0.030, 169), (0.026, 0), (0.017, 199), (0.017, 148), (0.009, 56), (0.004, 3)] + else: + return [(0.460, 0), (0.145, 62), (0.138, 63), (0.081, 35), (0.047, 33), (0.043, 189), (0.023, 64), (0.018, 77), (0.013, 169), (0.009, 65), (0.009, 56), (0.005, 32), (0.005, 59), (0.002, 183), (0.002, 163)] diff --git a/.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py b/.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py new file mode 100644 index 0000000000000000000000000000000000000000..64f0c59905350618cff3f359657f681733e6f9df --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py @@ -0,0 +1,149 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/ +from typing import List, Optional, Tuple + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class MixedMMH100(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: List[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 232448 + and str(metadata.device_capa) == "(9, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.0 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('extern_fallback_mixed_mm') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + 
self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4') + self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8') + + def get_name(self) -> str: + return 'mixed_mm' + + def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]: + if context.get_value('arith_intensity') <= 15.988086223602295: + if context.get_value('n') <= 25280.0: + if context.get_value('n') <= 1344.0: + if context.get_value('mat1_stride_0') <= 7808.0: + return [(0.581, 7), (0.419, 6)] + else: + if context.get_value('m*n') <= 7680.0: + return [(0.875, 0), (0.125, 6)] + else: + return [(0.833, 0), (0.167, 7)] + else: + if context.get_value('n') <= 8512.0: + if str(context.get_value('mat2_dtype')) != 'torch.int8': + return [(0.763, 6), (0.237, 7)] + else: + return [(0.725, 7), (0.275, 6)] + else: + if str(context.get_value('mat1_dtype')) != 'torch.bfloat16': + return [(0.736, 7), (0.197, 9), (0.048, 6), (0.014, 8), (0.005, 10)] + else: + return [(0.473, 7), (0.398, 6), (0.097, 9), (0.032, 10)] + else: + if context.get_value('n') <= 42254.0: + if context.get_value('n') <= 33856.0: + if context.get_value('k*n') <= 68157440.0: + return [(0.370, 4), (0.370, 5), (0.074, 7), (0.074, 8), (0.074, 11), (0.037, 6)] + else: + return [(0.916, 8), (0.036, 7), (0.036, 9), (0.012, 4)] + else: + return [(0.659, 5), (0.341, 6)] + else: + if context.get_value('k*n') <= 326052992.0: + if context.get_value('n') <= 55232.0: + return [(0.571, 6), (0.321, 7), (0.036, 4), (0.036, 8), (0.036, 9)] + else: + return [(0.506, 6), (0.325, 8), (0.104, 7), (0.039, 5), (0.026, 9)] + else: + if context.get_value('n') <= 57024.0: + return [(0.462, 9), (0.385, 7), (0.115, 6), (0.038, 8)] + else: + return [(0.598, 8), (0.223, 9), (0.107, 6), (0.071, 7)] + else: + if context.get_value('m*n') <= 543936.0: + if str(context.get_value('17LEQmLEQ32')) != 'True': + if context.get_value('m*n') <= 262272.0: + if context.get_value('n') <= 1592.5: + return [(0.860, 0), (0.140, 9)] + else: + return None + else: + if context.get_value('m*k') <= 1294336.0: + return [(0.833, 17), (0.150, 18), (0.017, 15)] + else: + return [(0.917, 17), (0.083, 8)] + else: + if context.get_value('n') <= 12416.0: + if context.get_value('m*n') <= 43008.0: + return None + else: + return [(0.853, 14), (0.147, 9)] + else: + return [(0.625, 12), (0.375, 14)] + else: + if context.get_value('m') <= 32.5: + if context.get_value('mat2_stride_1') <= 6656.0: + if context.get_value('n') <= 69184.0: + return [(0.611, 12), (0.361, 14), (0.028, 13)] + else: + return [(1.000, 12)] + else: + if context.get_value('mat2_stride_1') <= 20864.0: + return [(1.000, 12)] + else: + return [(0.958, 12), (0.042, 9)] + else: + if context.get_value('m*n') <= 1085440.0: + if context.get_value('n') <= 9152.0: + return [(1.000, 18)] + else: + return [(0.780, 18), (0.160, 16), (0.060, 20)] + else: + if context.get_value('m') <= 67.0: + return [(0.650, 16), (0.203, 19), (0.122, 
18), (0.016, 20), (0.008, 1)] + else: + return [(0.561, 3), (0.185, 16), (0.096, 20), (0.083, 19), (0.076, 2)] diff --git a/.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py b/.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py new file mode 100644 index 0000000000000000000000000000000000000000..3a1d390ff498b1bc18df73e0d8eb240a615d0788 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py @@ -0,0 +1,109 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/ +from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicRegression, +) + + +class PadMMA100(LearnedHeuristicRegression): + + def __init__(self) -> None: + pass + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 166912 + and str(metadata.device_capa) == "(8, 0)" + ) + + def get_feedback(self, context: AHContext, choice: Choice) -> float: + context.context_dict[CHOICE_COL] = choice + return self.predict(context) + + def get_confidence_threshold(self) -> float: + return 1.7025303314066 + + def get_name(self) -> str: + return 'pad_mm' + + def predict(self, context: AHContext) -> float: + if str(context.get_value('choice')) != 'pad': + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 4171264.0: + if context.get_value('m*k') <= 3999308.0: + return 1.8751469764071178 + else: + if str(context.get_value('n_multiple_32')) != 'True': + return 0.9117231355626345 + else: + return 1.1607689608873861 + else: + if str(context.get_value('n_multiple_2')) != 'True': + if str(context.get_value('using_tf32')) != 'True': + return 0.7430382200435992 + else: + return 0.8531269794448678 + else: + if str(context.get_value('k_multiple_2')) != 'True': + return 0.7577181972719917 + else: + return 0.8977349440424219 + else: + if context.get_value('m*n') <= 1299712.0: + return 1.1669723418995592 + else: + if context.get_value('mat2_stride_1') <= 45217.5: + if context.get_value('m*n') <= 55884158.0: + return 1.0262769936909601 + else: + return 1.0022677428470845 + else: + if context.get_value('m') <= 18478.0: + return 1.1127066261894312 + else: + return 1.0337740659894263 + else: + if str(context.get_value('mat1_dtype')) != 'torch.float32': + if str(context.get_value('n_multiple_2')) != 'False': + if str(context.get_value('k_multiple_2')) != 'True': + if context.get_value('mat1_stride_0') <= 561.0: + return 1.2900382135142956 + else: + return 1.5761737616057887 + else: + if context.get_value('num_dims_needs_padding') <= 1.5: + return 1.0472263310239422 + else: + return 1.1727673465762514 + else: + if context.get_value('k') <= 28238.5: + if context.get_value('k/(m*n)') <= 0.00026227018679492176: + return 1.6770542505397175 + else: + return 1.3974785435105923 + else: + if str(context.get_value('mat1_dtype')) != 'torch.bfloat16': + return 1.3952699800111992 + else: + return 1.5759286511628336 + else: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 14119424.0: + return 0.8875772670422478 + else: + if str(context.get_value('mat2_innermost_needs_padding')) != 'True': + return 
1.1467728924377265 + else: + return 1.215842963532998 + else: + if context.get_value('arith_intensity') <= 396.8774871826172: + return 0.89940161869551 + else: + if context.get_value('mat2_stride_1') <= 45217.5: + return 0.9964328169353532 + else: + return 0.9493479238294826 diff --git a/.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py b/.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/Lib/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py b/.venv/Lib/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fec11ac879c1ce0147086e52754af354083111f7 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py @@ -0,0 +1,32 @@ +# mypy: allow-untyped-defs +import re + +import torch +from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE + + +# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like: +# "... +# from ..codecache import CudaKernelParamCache +# ..." +# In such cases, we do not need to hipify_torch the orignial class/file name in codegen/codecache + + +def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str: + if torch.version.hip is None and not force_hipify: + return source_codes + + def c2_repl(m): + return PYTORCH_MAP[m.group(0)] + + # We need to redefine RE_PYTORCH_PREPROCESSOR here since in hipify_torch, + # it will apply positive lookbehind (?<=\W) to the pattern to avoid matching + # keyword at the beginning of code line. However, this can happen in codegen, + # which will cause the pattern to not match. + + # Note that lookahead (?=\W) is still needed to keep hipification idomponent, for example + # we need to skip replacing "getStreamFromExternal" in "getStreamFromExternalMasqueradingAsCUDA" + RE_PYTORCH_PREPROCESSOR = re.compile(rf"({PYTORCH_TRIE.export_to_regex()})(?=\W)") + + source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes) + return source_codes diff --git a/.venv/Lib/site-packages/torch/_inductor/codegen/codegen_device_driver.py b/.venv/Lib/site-packages/torch/_inductor/codegen/codegen_device_driver.py new file mode 100644 index 0000000000000000000000000000000000000000..c9850cf51c9f80322a447ce8764ee8a9dd99fe63 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/codegen/codegen_device_driver.py @@ -0,0 +1,91 @@ +import torch + + +# Provide aoti module launch hip/cuda drivers. 
This file is also used for unit testing purpose + + +def cuda_kernel_driver() -> str: + source_codes = """ + #define CUDA_DRIVER_CHECK(EXPR) \\ + do { \\ + CUresult code = EXPR; \\ + const char *msg; \\ + cuGetErrorString(code, &msg); \\ + if (code != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string(msg)); \\ + } \\ + } while (0); + + namespace { + + struct Grid { + Grid(uint32_t x, uint32_t y, uint32_t z) + : grid_x(x), grid_y(y), grid_z(z) {} + uint32_t grid_x; + uint32_t grid_y; + uint32_t grid_z; + + bool is_non_zero() { + return grid_x > 0 && grid_y > 0 && grid_z > 0; + } + }; + + } // anonymous namespace + + static inline CUfunction loadKernel( + std::string filePath, + const std::string &funcName, + uint32_t sharedMemBytes, + const std::optional &cubinDir = std::nullopt) { + if (cubinDir) { + std::filesystem::path p1{*cubinDir}; + std::filesystem::path p2{filePath}; + filePath = (p1 / p2.filename()).string(); + } + + CUmodule mod; + CUfunction func; + CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str())); + CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); + if (sharedMemBytes > 0) { + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes + )) + } + return func; + } + + static inline void launchKernel( + CUfunction func, + uint32_t gridX, + uint32_t gridY, + uint32_t gridZ, + uint32_t numWarps, + uint32_t sharedMemBytes, + void* args[], + cudaStream_t stream) { + CUDA_DRIVER_CHECK(cuLaunchKernel( + func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr + )); + } + """ + if torch.version.hip is not None: + # Adjusting the warp size to GPU supported wavefront size on AMD GPU + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + source_codes = source_codes.replace( + "32*numWarps", str(prop.warp_size) + "*numWarps" + ) + return source_codes + + +def cuda_kernel_header() -> str: + source_codes = """ + #include + #include + #include + """ + return source_codes diff --git a/.venv/Lib/site-packages/torch/_inductor/codegen/common.py b/.venv/Lib/site-packages/torch/_inductor/codegen/common.py new file mode 100644 index 0000000000000000000000000000000000000000..3da9846c641eed0e261e02f2bac50bd58c9b9111 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/codegen/common.py @@ -0,0 +1,2167 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +import re +from enum import auto, Enum +from itertools import chain +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +import sympy +from sympy.printing.printer import Printer + +import torch +import torch.fx +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT +from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges + +from .. 
import config, metrics +from ..utils import ( + DeferredLineBase, + generate_assert, + IndentedBuffer, + sympy_dot, + sympy_subs, + unique, +) +from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V + + +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") + + +def data_type_logger(msg): + if schedule_log.isEnabledFor(logging.DEBUG): + schedule_log.debug("Data type propagation: %s", msg) + + +@dataclasses.dataclass +class WorkspaceArg: + """A temporary buffer used for a single kernel, then discarded. + + Not registered as a traditional buffer since there are no users, + so it would be dead code eliminated. + """ + + nbytes: sympy.Expr + zero_fill: bool + + +@dataclasses.dataclass +class TensorArg: + name: str + buffer: str + dtype: torch.dtype + offset: sympy.Expr = sympy.Integer(0) # c++ only + alias_of: Optional[str] = None # halide only + + +@dataclasses.dataclass +class SizeArg: + name: str + expr: sympy.Expr + + @property + def alias_of(self): + return None + + +@dataclasses.dataclass +class DeviceCodegen: + scheduling: Any + wrapper_codegen: type + cpp_wrapper_codegen: type = type(None) + + +KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg] + +device_codegens: Dict[str, DeviceCodegen] = {} + + +class DeviceOpOverrides: + def import_get_raw_stream_as(self, name): + raise NotImplementedError + + def set_device(self, device_idx): + raise NotImplementedError + + def synchronize(self): + raise NotImplementedError + + def device_guard(self, device_idx): + raise NotImplementedError + + +device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {} + + +# The code generated by Inductor consists of two main parts: kernel code and wrapper code. +# For any new backend looking to integrate with Inductor, customization of these two main +# parts are necessary to generate its specific code. +# +# Kernel code generation is determined by different Scheduling. Consequently, a new +# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently, +# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively. +# +# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code +# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen, +# and override specific member functions to create backend-specific Python wrapper code. +# +# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part +# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces +# provide flexibility to the backend. A backend can choose to implement these classes from scratch, +# or reuse them by extending and overriding as necessary. And Inductor provides the registration API, +# register_backend_for_device, to equip a new backend at runtime. +# +# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces. 
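As a minimal sketch of the registration flow described in the comment above (the `register_backend_for_device` API is defined just below): `"my_device"`, `MyScheduling`, and `MyWrapperCodeGen` are hypothetical placeholders, and a real backend would have to implement the full Scheduling / WrapperCodeGen interfaces rather than these stubs.

```python
# Hypothetical sketch of the registration call shape; not a working backend.
from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.wrapper import WrapperCodeGen


class MyScheduling:
    """Placeholder for a backend-specific kernel-codegen Scheduling class."""

    def __init__(self, scheduler):
        self.scheduler = scheduler


class MyWrapperCodeGen(WrapperCodeGen):
    """Reuses the default Python wrapper codegen; override hooks as needed."""


# Equip the (hypothetical) device at runtime, as the comment above describes.
register_backend_for_device("my_device", MyScheduling, MyWrapperCodeGen)
```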
+# This backend can be used as a reference: +# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9 +def register_backend_for_device( + device: str, + device_scheduling: Any, + device_wrapper_codegen: type, + device_cpp_wrapper_codegen: type = type(None), +): + device_codegens[device] = DeviceCodegen( + device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen + ) + + +class BackendFeature(Enum): + FOREACH = auto() + BUCKETIZE = auto() + INPLACE_BUFFERS = auto() + MASKED_SCATTER_WITH_INDEX = auto() + SCAN = auto() + SORT = auto() + TUPLE_REDUCTION = auto() + PREFER_STORE_LOOP_ORDER = auto() + TRITON_TEMPLATES = auto() + REDUCE_TO_SINGLE_ELEMENT = auto() + + +def get_backend_features(device: Union[torch.device, str]): + init_backend_registration() + if isinstance(device, torch.device): + device_type = device.type + else: + assert isinstance(device, str) + device_type = device + device = torch.device(device_type) + scheduling = get_scheduling_for_device(device_type) + return scheduling(None).get_backend_features(device) + + +def has_backend_feature(device, feature): + """See also V.graph.has_feature""" + assert isinstance(feature, BackendFeature) + return feature in get_backend_features(device) + + +def get_scheduling_for_device(device: str): + return device_codegens[device].scheduling if device in device_codegens else None + + +def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False): + if device in device_codegens: + wrapper_codegen_obj: DeviceCodegen = device_codegens[device] + return ( + wrapper_codegen_obj.cpp_wrapper_codegen + if cpp_wrapper + else wrapper_codegen_obj.wrapper_codegen + ) + else: + return None + + +@functools.lru_cache(None) +def init_backend_registration(): + from .cpp import CppScheduling + from .cpp_wrapper_cpu import CppWrapperCpu + from .cpp_wrapper_cuda import CppWrapperCuda + from .cuda_combined_scheduling import CUDACombinedScheduling + from .halide import HalideScheduling + from .triton import TritonScheduling + from .wrapper import WrapperCodeGen + + if get_scheduling_for_device("cpu") is None: + cpu_backends = {"cpp": CppScheduling, "halide": HalideScheduling} + register_backend_for_device( + "cpu", + lambda *args, **kwargs: cpu_backends[config.cpu_backend](*args, **kwargs), + WrapperCodeGen, + CppWrapperCpu, + ) + + if get_scheduling_for_device("cuda") is None: + # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation + cuda_backends = {"triton": CUDACombinedScheduling, "halide": HalideScheduling} + register_backend_for_device( + "cuda", + lambda *args, **kwargs: cuda_backends[config.cuda_backend](*args, **kwargs), + WrapperCodeGen, + CppWrapperCuda, + ) + + if get_scheduling_for_device("xpu") is None: + register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen) + + private_backend = torch._C._get_privateuse1_backend_name() + if ( + private_backend != "privateuseone" + and get_scheduling_for_device(private_backend) is None + ): + from torch.utils.backend_registration import _get_custom_mod_func + + try: + device_scheduling = _get_custom_mod_func("Scheduling") + wrapper_codegen = _get_custom_mod_func("WrapperCodeGen") + cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodeGen") + if device_scheduling and wrapper_codegen and cpp_wrapper_codegen: + register_backend_for_device( + private_backend, + device_scheduling, + wrapper_codegen, + cpp_wrapper_codegen, + ) + except 
RuntimeError: + pass + + +def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes): + from ..ir import FlexibleLayout + + # added contiguous index prevents reordering + return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))] + + +def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides): + device_op_overrides_dict[device] = device_op_overrides + + +def get_device_op_overrides(device: str): + assert isinstance(device, str) + + if not device_op_overrides_dict.keys(): + from .cuda import device_op_overrides # noqa: F401 + from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401 + + if device in device_op_overrides_dict.keys(): + return device_op_overrides_dict[device] + + +@functools.lru_cache(None) +def boolean_ops(): + return ( + "isinf", + "isnan", + "logical_not", + "signbit", + "le", + "lt", + "ge", + "gt", + "eq", + "ne", + ) + + +DTYPE_TO_COMPUTATION_DTYPE = { + torch.bfloat16: torch.float, + torch.float16: torch.float, + **{ + dtype: dtype + for dtype in [ + torch.bool, + torch.float32, + torch.float64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + ] + }, +} + + +def deduce_output_dtype_by_name( + op_name: str, + *args, + **kwargs, +) -> Optional[torch.dtype]: + """ + Given op name and a list of input dtypes, deduce the output dtype + """ + if op_name in boolean_ops(): + return torch.bool + elif op_name in ( + "to_dtype", + "index_expr", + ): + return kwargs["dtype"] if "dtype" in kwargs else args[-1] + elif op_name in ( + "rand", + "randn", + ): + return torch.float + elif op_name in ( + "get_index", + "randint64", + "load_seed", + ): + return torch.int64 + elif op_name == "reduction": + return kwargs["dtype"] if "dtype" in kwargs else args[1] + elif op_name == "constant": + dtype = kwargs["dtype"] if "dtype" in kwargs else args[-1] + return DTYPE_TO_COMPUTATION_DTYPE[dtype] # type: ignore[index] + elif op_name in ( + "load", + "store", + "store_reduction", + ): + buf_name = args[1] + return V.graph.get_dtype(buf_name) # type: ignore[arg-type] + elif op_name == "to_dtype_bitcast": + return kwargs["dtype"] if "dtype" in kwargs else args[-2] + return None + + +class DataTypePropagation: + def __init__(self, body) -> None: + self.body = body + self.graphs: Dict[Union[Callable[..., Any], str], Any] = { + "root": body.root_block.graph + } + for k, v in body.subblocks.items(): + self.graphs[k] = v.graph + + def deduce_node_dtype_by_inputs(self, node: torch.fx.Node): + inputs = node.all_input_nodes + input_nodes = [ + n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder" + ] + if len(input_nodes) == 0: + return None + + all_input_nodes_propagated = all( + OptimizationContext.key in n.meta + and n.meta[OptimizationContext.key].dtype is not None + for n in input_nodes + ) + if not all_input_nodes_propagated: + return None + + return functools.reduce( + torch.promote_types, + [n.meta[OptimizationContext.key].dtype for n in input_nodes], + ) + + def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node): + sub_graph = self.graphs[node.target] + dtype = self.propagate_graph(sub_graph) + assert dtype + return dtype + + def deduce_node_dtype(self, node: torch.fx.Node): + if node.op == "placeholder": + return None + + if node.target == "output" and len(node.args) != 1: + # we can infer output node if it only have 1 arg + return None + + if node.target == operator.getitem: + return self.deduce_node_dtype(node.args[0]) 
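A minimal sketch of the name-based dtype rules in `deduce_output_dtype_by_name`, defined earlier in this file and called by `deduce_node_dtype` below; it assumes this torch build is importable, and the argument values are illustrative only.

```python
import torch
from torch._inductor.codegen.common import deduce_output_dtype_by_name

# Boolean predicate ops always produce torch.bool.
assert deduce_output_dtype_by_name("isnan") is torch.bool
# Random fills are float; index computations are int64.
assert deduce_output_dtype_by_name("rand") is torch.float
assert deduce_output_dtype_by_name("get_index") is torch.int64
# "to_dtype" takes its dtype from the last positional arg (or a dtype= kwarg).
assert deduce_output_dtype_by_name("to_dtype", "buf0", torch.float16) is torch.float16
# "constant" maps low-precision floats to their computation dtype.
assert deduce_output_dtype_by_name("constant", 1.0, torch.bfloat16) is torch.float
```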
# type: ignore[arg-type] + + assert isinstance(node.target, str) + + if node.target.startswith("masked_subblock"): + return self.deduce_node_dtype_by_subgraph(node) + + if ( + output_dtype := deduce_output_dtype_by_name( + node.target, + *node.args, + **node.kwargs, + ) + ) is not None: + return output_dtype + + return self.deduce_node_dtype_by_inputs(node) + + def propagate_graph(self, graph: torch.fx.Graph): + assert graph.nodes + graph_dtype = None + # For masked_subblock, we use output's dtype to represent + # the dtype of this subgraph. For other cases, graph_dtype + # might be None + for node in graph.nodes: + if OptimizationContext.key in node.meta: + opt_ctx = node.meta[OptimizationContext.key] + else: + opt_ctx = OptimizationContext() + + opt_ctx.dtype = self.deduce_node_dtype(node) + node.meta[OptimizationContext.key] = opt_ctx + if node.target == "output": + graph_dtype = opt_ctx.dtype + return graph_dtype + + def propagate(self): + self.propagate_graph(self.graphs["root"]) + + @classmethod + def propagate_loopbody(cls, body): + return cls(body).propagate() + + @classmethod + def propagate_scheduler_node(cls, node): + from ..loop_body import LoopBody + from ..scheduler import SchedulerNode + + assert isinstance(node, SchedulerNode) + assert isinstance(node._body, LoopBody) + DataTypePropagation.propagate_loopbody(node._body) + + +# This printer contains rules that are supposed to be generic for both C/C++ and +# Python +class ExprPrinter(Printer): + @staticmethod + def paren(string): + def all_in_parens(string): + if string[0] != "(" or len(string) < 2: + return False + count = 1 + for i, char in enumerate(string[1:]): + if char == "(": + count += 1 + elif char == ")": + count -= 1 + if count == 0 and i != len(string) - 2: + return False + assert count == 0 + return True + + if ( + isinstance(string, CSEVariable) + or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE) + or re.match(r"^\([^)]*\)$", string, re.IGNORECASE) + or string == "" + ): + return string + # don't put extra parens for strings that are already wrapped in parens + if all_in_parens(string): + return string + return f"({string})" + + def _print_Relational(self, expr): + return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args))) + + def _print_Mul(self, expr): + return "*".join(map(self.paren, map(self._print, expr.args))) + + def _print_Add(self, expr): + return " + ".join(map(self.paren, map(self._print, expr.args))) + + # NB: this is OK to put here, because Mod is only defined for positive + # numbers, and so across C/Python its behavior is consistent + def _print_Mod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_FloatTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + def _print_CleanDiv(self, expr): + return self._print_FloorDiv(expr) + + def _print_Identity(self, expr): + return self._print(expr.args[0]) + + def _print_GreaterThan(self, expr): + # GreaterThan: >= + # StrictlyGreaterThan: > + # Go figure... + return " >= ".join(map(self.paren, map(self._print, expr.args))) + + # NB: The C implementation is injected into codegen at + # torch/_inductor/codegen/wrapper.py + def _print_align(self, expr): + assert len(expr.args) == 1 + return f"align({self._print(expr.args[0])})" + + # This must be implemented because sympy will collect x * x into Pow(x, 2), without + # any explicit intervention. We print it just like x * x, notably, we + # never generate sympy.Pow with floats. 
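A quick sketch of the power-printing rule this comment describes: `_print_Pow`, defined just below, expands a small natural-number exponent into repeated multiplication rather than emitting a pow call. The snippet assumes `sympy` and this torch build are importable.

```python
import sympy
from torch._inductor.codegen.common import ExprPrinter

x = sympy.Symbol("x")
printer = ExprPrinter()
# simplify=False avoids touching graph state (V.graph) inside doprint.
print(printer.doprint(x**2, simplify=False))  # x*x
print(printer.doprint(x**5, simplify=False))  # x*x*x*x*x
```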
+ # + # NB: this pow by natural, you should never have used builtin sympy.pow + # for FloatPow, and a symbolic exponent should be PowByNatural. These + # means exp is guaranteed to be integer. + def _print_Pow(self, expr): + base, exp = expr.args + base = self._print(base) + assert exp == int(exp), exp + exp = int(exp) + assert exp >= 0 + if exp > 0: + return "*".join([self.paren(base)] * exp) + else: # exp == 0 + return "1" + + # Explicit NotImplemented functions are to prevent default sympy printing + # behavior, which will just barf out ToFloat(...) to your IR. The error + # message is better here because it tells you which printer class it needs + # to go in. + + def _print_ToFloat(self, expr): + raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}") + + def _print_Infinity(self, expr): + raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}") + + def _print_NegativeInfinity(self, expr): + raise NotImplementedError( + f"_print_NegativeInfinity not implemented for {type(self)}" + ) + + def _print_FloorDiv(self, expr): + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + + def _print_PythonMod(self, expr): + raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}") + + def _print_IntTrueDiv(self, expr): + raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}") + + def _print_PowByNatural(self, expr): + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatPow(self, expr): + raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}") + + def _print_TruncToInt(self, expr): + raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}") + + def _print_RoundToInt(self, expr): + raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}") + + def _print_RoundDecimal(self, expr): + raise NotImplementedError( + f"_print_RoundDecimal not implemented for {type(self)}" + ) + + # NB: Some float operations are INTENTIONALLY not implemented for + # printers. 
You can implement them as a quick unblock, but it is better + # to ask yourself why we haven't done this computation in the Tensor + # universe instead + + def _print_TruncToFloat(self, expr): + raise NotImplementedError( + f"_print_TruncToFloat not implemented for {type(self)}" + ) + + def doprint(self, expr, *, simplify: bool = True): + # TODO: why are people passing strings to the printer here :think: + if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): + expr = V.graph.sizevars.simplify(expr) + return super().doprint(expr) + + +class PythonPrinter(ExprPrinter): + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"float({self._print(expr.args[0])})" + + def _print_ModularIndexing(self, expr): + x, div, mod = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + mod = self.paren(self.doprint(mod)) + if div != "1": + x = f"({x} // {div})" + return f"{x} % {mod}" + + def _print_Infinity(self, expr): + return "math.inf" + + def _print_NegativeInfinity(self, expr): + return "-math.inf" + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_FloorDiv(self, expr): + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + return f"({x} // {div})" + + # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python + # does a special algorithm + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + def _helper_sqrt(self, expr): + return f"math.sqrt({self._print(expr)})" + + def _print_OpaqueUnaryFn_sqrt(self, expr): + return self._helper_sqrt(expr.args[0]) + + def _print_FloatPow(self, expr): + base, exp = expr.args + return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" + + # TODO: Not sure this works with Triton, even when base/exp are integral + def _print_PowByNatural(self, expr): + base, exp = expr.args + return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" + + def _print_floor(self, expr): + assert len(expr.args) == 1 + return f"math.floor({self._print(expr.args[0])})" + + def _print_FloorToInt(self, expr): + assert len(expr.args) == 1 + return f"math.floor({self._print(expr.args[0])})" + + def _print_TruncToInt(self, expr): + assert len(expr.args) == 1 + # This also could have been int(), they'll do the same thing for float + return f"math.trunc({self._print(expr.args[0])})" + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + return f"math.ceil({self._print(expr.args[0])})" + + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + return f"math.ceil({self._print(expr.args[0])})" + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return f"abs({self._print(expr.args[0])})" + + # NB: It's expected that we've made explicit any promotion in the sympy + # expression, so it doesn't matter that Python max/min doesn't perform + # promotion + def _print_Max(self, expr): + assert len(expr.args) >= 2 + return f"max({', '.join(map(self._print, expr.args))})" + + def _print_Min(self, expr): + assert len(expr.args) >= 2 + return f"min({', '.join(map(self._print, expr.args))})" + + def _print_OpaqueUnaryFn_cos(self, expr): + assert len(expr.args) == 1 + return f"math.cos({self._print(expr.args[0])})" + + def 
_print_OpaqueUnaryFn_cosh(self, expr): + assert len(expr.args) == 1 + return f"math.cosh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr): + assert len(expr.args) == 1 + return f"math.acos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr): + assert len(expr.args) == 1 + return f"math.sin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr): + assert len(expr.args) == 1 + return f"math.sinh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr): + assert len(expr.args) == 1 + return f"math.asin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr): + assert len(expr.args) == 1 + return f"math.tan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr): + assert len(expr.args) == 1 + return f"math.tanh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr): + assert len(expr.args) == 1 + return f"math.atan({self._print(expr.args[0])})" + + def _print_RoundToInt(self, expr): + assert len(expr.args) == 1 + return f"round({self._print(expr.args[0])})" + + def _print_RoundDecimal(self, expr): + assert len(expr.args) == 2 + number, ndigits = expr.args + assert isinstance(ndigits, sympy.Integer) + return f"round({self._print(number)}, {ndigits})" + + +class OpOverrides: + def __init__(self, parent): + super().__init__() + self._parent = parent + + def __getattr__(self, item): + return getattr(self._parent, item) + + @staticmethod + def identity(value): + # used to trigger cse + return value + + @staticmethod + def constant(value, dtype): + return repr(value) + + @staticmethod + def reciprocal(x): + return ops.truediv(ops.constant(1, torch.int32), x) + + @staticmethod + def square(x): + return ops.mul(x, x) + + @staticmethod + def erfc(x): + return ops.sub(ops.constant(1, torch.float32), ops.erf(x)) + + @staticmethod + def erfcx(x): + return ops.mul(ops.exp(ops.square(x)), ops.erfc(x)) + + @staticmethod + def expm1(x): + return ops.sub(ops.exp(x), ops.constant(1, torch.float32)) + + @staticmethod + def log10(x): + return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32)) + + @staticmethod + def log2(x): + return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32)) + + @staticmethod + def exp2(x): + return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32))) + + @staticmethod + def log1p(x): + return ops.log(ops.add(x, ops.constant(1, torch.int32))) + + @staticmethod + def sigmoid(x): + one = ops.constant(1, torch.int32) + return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x)))) + + @staticmethod + def libdevice_sigmoid(x): + one = ops.constant(1, torch.int32) + return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x)))) + + @staticmethod + def relu(x): + return ops.maximum(x, ops.constant(0, torch.int32)) + + @staticmethod + def libdevice_abs(x): + return ops.abs(x) + + @staticmethod + def libdevice_sqrt(x): + return ops.sqrt(x) + + @staticmethod + def libdevice_cos(x): + return ops.cos(x) + + @staticmethod + def libdevice_sin(x): + return ops.sin(x) + + @staticmethod + def libdevice_log(x): + return ops.log(x) + + @staticmethod + def libdevice_exp(x): + return ops.exp(x) + + @staticmethod + def bitwise_not(x): + return f"~{ExprPrinter.paren(x)}" + + @staticmethod + def logical_not(a): + return f"{ExprPrinter.paren(a)} == 0" + + @staticmethod + def bitwise_and(x, y): + return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_or(x, y): + return 
f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_xor(x, y): + return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_left_shift(x, y): + return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}" + + @staticmethod + def bitwise_right_shift(x, y): + return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}" + + @staticmethod + def remainder(a, b): + r = ops.mod(a, b) + cond = ops.and_( + ops.ne(r, ops.constant(0, torch.int32)), + ops.ne(ops.signbit(r), ops.signbit(b)), + ) + return ops.where(cond, ops.add(r, b), r) + + @staticmethod + def trunc_to_int(a, dtype): + return ops.to_dtype(ops.trunc(a), dtype) + + @staticmethod + def floor_to_int(a, dtype): + return ops.to_dtype(ops.floor(a), dtype) + + @staticmethod + def ceil_to_int(a, dtype): + return ops.to_dtype(ops.ceil(a), dtype) + + @staticmethod + def round_to_int(a, dtype): + return ops.to_dtype(ops.round(a), dtype) + + @staticmethod + def int_truediv(a, b): + # TODO: this is wrong + # TODO: an easy bandaid is to generate runtime asserts that it's + # <= 2**53, which is when this equation is correct + return ops.truediv(a, b) + + @staticmethod + def load_seed(name, offset): + return ops.load(name, sympy.Integer(offset)) + + @classmethod + def _initialize_pointwise_overrides(cls, target): + assert target in {"triton", "cpp", "cppvec"}, target + + for funcname, data in pointwise_overrides_data.items(): + impl = getattr(data, target) + if impl is None: + continue + setattr(cls, funcname, staticmethod(impl)) + + +@dataclasses.dataclass +class OverridesData: + name: str + cpp: Callable[..., str] + # None when not impl in libdevice/triton + triton: Optional[Callable[..., str]] = None + # None when not impl in aten/.../vec + cppvec: Optional[Callable[..., str]] = None + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = ( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +# NB: if you add a new special function, don't forget to update +# torch._inductor.ops_handler too +pointwise_overrides_data: Dict[str, OverridesData] = dict( + airy_ai=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"airy_ai_forward({x})", + name="special_airy_ai", + ), + bessel_j0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_j0_forward({x})", + triton=lambda x: f"libdevice.j0({x})", + name="special_bessel_j0", + ), + bessel_j1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_j1_forward({x})", + triton=lambda x: f"libdevice.j1({x})", + name="special_bessel_j1", + ), + bessel_y0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_y0_forward({x})", + triton=lambda x: f"libdevice.y0({x})", + name="special_bessel_y0", + ), + bessel_y1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"bessel_y1_forward({x})", + triton=lambda x: f"libdevice.y1({x})", + name="special_bessel_y1", + ), + digamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_digamma({x})", + cppvec=lambda x: f"{x}.digamma()", + name="digamma", + ), + # no cpp nor triton implementation for entr, it is defined as decomposition + # erf, erfc + erfcx=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_erfcx({x})", + triton=lambda x: f"libdevice.erfcx({x})", + 
name="special_erfcx", + ), + fma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})", + cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})", + triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})", + name="fma", + ), + # erfinv, exp2, expit, gammaln + igamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igamma({x}, {y})", + name="igamma", + ), + igammac=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igammac({x}, {y})", + name="igammac", + ), + gammainc=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igamma({x}, {y})", + name="special_gammainc", + ), + gammaincc=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_igammac({x}, {y})", + name="special_gammaincc", + ), + i0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i0({x})", + triton=lambda x: f"libdevice.cyl_bessel_i0({x})", + cppvec=lambda x: f"{x}.i0()", + name="i0", + ), + i0e=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i0e({x})", + cppvec=lambda x: f"{x}.i0e()", + name="special_i0e", + ), + i1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i1({x})", + triton=lambda x: f"libdevice.cyl_bessel_i1({x})", + name="special_i1", + ), + i1e=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_i1e({x})", + name="special_i1e", + ), + log_ndtr=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_log_ndtr({x})", + name="special_log_ndtr", + ), + # logit + modified_bessel_i0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_i0_forward({x})", + triton=lambda x: f"libdevice.cyl_bessel_i0({x})", + name="special_modified_bessel_i0", + ), + modified_bessel_i1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_i1_forward({x})", + triton=lambda x: f"libdevice.cyl_bessel_i1({x})", + name="special_modified_bessel_i1", + ), + modified_bessel_k0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_k0_forward({x})", + name="special_modified_bessel_k0", + ), + modified_bessel_k1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"modified_bessel_k1_forward({x})", + name="special_modified_bessel_k1", + ), + # multigamma + ndtr=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_ndtr({x})", + name="special_ndtr", + ), + ndtri=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"calc_ndtri({x})", + name="special_ndtri", + ), + polygamma=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"calc_polygamma({y}, {x})", + name="polygamma", + ), + # psi - alias to digamma + # round + scaled_modified_bessel_k0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})", + 
name="special_scaled_modified_bessel_k0", + ), + scaled_modified_bessel_k1=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})", + name="special_scaled_modified_bessel_k1", + ), + # sinc + spherical_bessel_j0=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x: f"spherical_bessel_j0_forward({x})", + name="special_spherical_bessel_j0", + ), + zeta=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"zeta({x}, {y})", + name="special_zeta", + ), + chebyshev_polynomial_t=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})", + name="special_chebyshev_polynomial_t", + ), + chebyshev_polynomial_u=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})", + name="special_chebyshev_polynomial_u", + ), + chebyshev_polynomial_v=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})", + name="special_chebyshev_polynomial_v", + ), + chebyshev_polynomial_w=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})", + name="special_chebyshev_polynomial_w", + ), + legendre_polynomial_p=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})", + name="special_legendre_polynomial_p", + ), + shifted_chebyshev_polynomial_t=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_t", + ), + shifted_chebyshev_polynomial_u=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_u", + ), + shifted_chebyshev_polynomial_v=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_v", + ), + shifted_chebyshev_polynomial_w=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})", + name="special_shifted_chebyshev_polynomial_w", + ), + hermite_polynomial_h=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})", + name="special_hermite_polynomial_h", + ), + hermite_polynomial_he=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})", + name="special_hermite_polynomial_he", + ), + laguerre_polynomial_l=OverridesData( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})", + name="special_laguerre_polynomial_l", + ), +) + + +# Use mypy to check protocol implemented correctly +def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]: + return h + + +class DeferredLine(DeferredLineBase): + """A line that can be 'unwritten' by adding 
name to V.graph.removed_buffers""" + + def __init__(self, name, line): + super().__init__(line) + self.name = name + assert not isinstance(line, DeferredLineBase) + + def __call__(self): + if all( + self.name not in x + for x in ( + V.graph.removed_buffers, + V.kernel.removed_buffers, + V.graph.inplaced_to_remove, + V.kernel.inplaced_to_remove, + ) + ): + return self.line + return None + + def _new_line(self, line): + return DeferredLine(self.name, line) + + +class BracesBuffer(IndentedBuffer): + def indent(self, offset=1): + @contextlib.contextmanager + def ctx(): + for _ in range(offset): + self.writeline("{") + self._indent += 1 + for _ in range(-offset): + self._indent -= 1 + self.writeline("}") + yield + for _ in range(-offset): + self.writeline("{") + self._indent += 1 + for _ in range(offset): + self._indent -= 1 + self.writeline("}") + + return ctx() + + +class InplacedBuffer(NamedTuple): + inner_name: str + other_names: List[str] + + +class KernelArgs: + @staticmethod + def _lookup(prefix, odict, name): + assert isinstance(name, (str, sympy.Symbol)) + if name not in odict: + odict[name] = f"{prefix}{len(odict)}" + return odict[name] + + def __init__(self, sizevars=None): + self.input_buffers = {} + self.output_buffers = {} + self.inplace_buffers = {} + self.sizevars = sizevars or {} + self.workspace_arg = None + + def __repr__(self): + return "KernelArgs({})".format( + ", ".join( + map( + repr, + [ + self.input_buffers, + self.output_buffers, + self.inplace_buffers, + self.sizevars, + ], + ) + ) + ) + + def _buffer_is_marked_removed(self, name): + return isinstance(name, str) and name.startswith("REMOVED") + + def input(self, name): + if V.graph.scheduler: + name = V.graph.scheduler.mutation_real_name.get(name, name) + assert name not in V.graph.removed_buffers, name + if name in self.output_buffers: + return self.output_buffers[name] + if name in self.inplace_buffers: + return self.inplace_buffers[name].inner_name + if name.startswith("seed"): + return self._lookup("seed", self.input_buffers, name) + return self._lookup("in_ptr", self.input_buffers, name) + + def output(self, name): + if V.graph.scheduler: + name = V.graph.scheduler.mutation_real_name.get(name, name) + assert name not in V.graph.removed_buffers, name + if name in self.inplace_buffers: + return self.inplace_buffers[name].inner_name + return self._lookup("out_ptr", self.output_buffers, name) + + def make_inplace(self, input_name, output_name): + assert output_name not in self.inplace_buffers + if input_name in self.inplace_buffers: + buf = self.inplace_buffers[input_name] + buf.other_names.append(output_name) + self.inplace_buffers[output_name] = buf + else: + buf = InplacedBuffer( + f"in_out_ptr{len(unique(self.inplace_buffers.values()))}", + [input_name, output_name], + ) + self.inplace_buffers[input_name] = buf + self.inplace_buffers[output_name] = buf + + def workspace(self, nbytes: sympy.Expr, zero_fill: bool): + if self.workspace_arg is None: + self.workspace_arg = WorkspaceArg(nbytes, zero_fill) + return "ws_ptr", 0 + + offset = self.workspace_arg.nbytes + zero_fill = zero_fill or self.workspace_arg.zero_fill + self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill) + return "ws_ptr", offset + + def seed_offset(self, name, value): + if value in self.sizevars: + return self.sizevars[value] + if name in self.sizevars.values(): + name = ( + f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}" + ) + self.sizevars[value] = name + return name + + def size(self, name): + if str(name) == 
"seed": + self.sizevars["seed"] = "seed" + return "seed" + return self._lookup("ks", self.sizevars, name) + + def call_names(self): + return chain( + self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys() + ) + + def wrap_ptr_arg(self, buf, dtype): + return buf + + def wrap_size_arg(self, size): + return str(size) + + def cpp_argdefs(self): + from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE + + call_args = [] + arg_defs = [] + arg_types = [] + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + outer = inplaced.other_names[-1] + inner = inplaced.inner_name + dtype = V.graph.get_dtype(outer) + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"{cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") + for outer, inner in self.input_buffers.items(): + if outer in self.inplace_buffers: + continue + dtype = V.graph.get_dtype(outer) + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"const {cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"const {cpp_dtype}*") + for outer, inner in self.output_buffers.items(): + if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): + continue + dtype = V.graph.get_dtype(outer) + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"{cpp_dtype}* {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") + for outer, inner in self.sizevars.items(): + arg_defs.append(f"const {INDEX_TYPE} {inner}") + call_args.append(self.wrap_size_arg(outer)) + arg_types.append(f"const {INDEX_TYPE}") + if V.graph.wrapper_code: + V.graph.wrapper_code.ensure_size_computed(outer) + assert self.workspace_arg is None, "Workspace not supported on CPU " + return arg_defs, call_args, arg_types + + def python_argdefs(self): + arg_defs: List[str] = [] + call_args: List[str] = [] + arg_types: List[torch.dtype] = [] + precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = [] + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + arg_defs.append(inplaced.inner_name) + call_args.append(inplaced.other_names[-1]) + arg_types.append(V.graph.get_dtype(inplaced.other_names[-1])) + precompile_args.append( + TensorArg( + name=inplaced.inner_name, + buffer=inplaced.other_names[-1], + dtype=V.graph.get_dtype(inplaced.other_names[-1]), + ) + ) + for outer, inner in chain( + self.input_buffers.items(), self.output_buffers.items() + ): + if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): + continue + arg_defs.append(inner) + call_args.append(outer) + arg_types.append(V.graph.get_dtype(outer)) + precompile_args.append( + TensorArg( + name=inner, + buffer=outer, + dtype=V.graph.get_dtype(outer), + ) + ) + for outer, inner in self.sizevars.items(): + arg_defs.append(inner) + call_args.append(outer) + arg_types.append(type(outer)) # type: ignore[arg-type] + precompile_args.append(SizeArg(inner, outer)) + if V.graph.wrapper_code: + V.graph.wrapper_code.ensure_size_computed(outer) + if self.workspace_arg is not None: + arg_defs.append("ws_ptr") + call_args.append("workspace") + precompile_args.append(self.workspace_arg) + return arg_defs, call_args, precompile_args, arg_types + + def aliases(self): + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + for other in inplaced.other_names: + if ( + other in 
V.graph.inplaced_to_remove + or other in V.kernel.inplaced_to_remove + ): + continue + if other in self.input_buffers: + yield self.input_buffers[other], inplaced.inner_name + if other in self.output_buffers: + yield self.output_buffers[other], inplaced.inner_name + + def is_removed(self, name): + def _is_removed(name, buffers): + return name not in buffers or self._buffer_is_marked_removed(buffers[name]) + + return _is_removed(name, self.output_buffers) and _is_removed( + name, self.inplace_buffers + ) + + # Includes inplace buffers, excludes removed buffers. Essentially, + # after you do a call into this kernel, which buffers actually contain + # updated data? Modeled off of python_argdefs. + def live_output_buffers(self): + live_outs = OrderedSet() # type: ignore[var-annotated] + for inplaced in unique(self.inplace_buffers.values()): + if self._buffer_is_marked_removed(inplaced): + continue + live_outs.add(inplaced.other_names[-1]) + for outer, inner in self.output_buffers.items(): + if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): + continue + live_outs.add(outer) + return live_outs + + +class CSEVariable: + """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis. + To do so, the backends can simply overload `Kernel.create_cse_var` + The "CSEVariable.update_on_args" method gives you a hook for annotations + See example of TritonCSEVariable in triton.py + """ + + def __init__(self, name, bounds: ValueRanges[Any]): + assert isinstance(bounds, ValueRanges) + self.name = name + self.bounds = bounds + self.use_count = 1 # track how many tims this expression is used + + def __str__(self): + return self.name + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other) -> bool: + return type(other) == type(self) and other.name == self.name + + def update_on_args(self, name, args, kwargs): + pass + + def __repr__(self): + return f"{self.__class__.__name__}({self.name!r})" + + +class CppWrapperKernelArgs(KernelArgs): + def wrap_ptr_arg(self, buf, dtype): + from .cpp_utils import DTYPE_TO_CPP + + if config.abi_compatible: + # In the abi_compatible model, we just return the buf here. + # We will form correct call args later in wrapper.generate_kernel_all. 
+ return buf + else: + return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())" + + def wrap_size_arg(self, size): + return f"{size}" + + +class CSE: + """Common subexpression elimination""" + + def __init__( + self, + prefix="", + suffix="", + name_prefix="tmp", + iter_buffers=None, + store_cache=None, + reduction_cache=None, + varname_map=None, + ): + self.prefix = prefix + self.suffix = suffix + self.cache = {} + self.name_prefix = name_prefix + self.store_cache = store_cache or {} + self.reduction_cache = reduction_cache or {} + self.iter_buffer_ids = iter_buffers or itertools.count() + self.invalidated_stores = OrderedSet() # type: ignore[var-annotated] + self.varname_map = varname_map or {} + + def invalidate(self, keep_vars: OrderedSet[str]): + for name, tmp in list(self.store_cache.items()): + if tmp not in keep_vars: + del self.store_cache[name] + self.invalidated_stores.add(name) + self.cache = {k: v for k, v in self.cache.items() if v in keep_vars} + + def clone(self): + # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional + return CSE( + prefix=self.prefix, + suffix=self.suffix, + name_prefix=self.name_prefix, + iter_buffers=self.iter_buffer_ids, + store_cache=self.store_cache, + varname_map=self.varname_map, + ) + + def generate( + self, + buffer: IndentedBuffer, + expr: Union[str, CSEVariable, OpsValue, IndentedBuffer], + *, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + write=True, + assignment=True, + ) -> CSEVariable: + if isinstance(expr, OpsValue): + expr = expr.value + + assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr) + assert write or assignment + if isinstance(expr, CSEVariable): + # If the expressions were always created with all the information, we could + # assert expr.bounds == bounds, but sometimes the expression is created + # with the loose ValueRanges.unknown(), so we need to tighten the bounds + expr.bounds = expr.bounds.tighten(bounds) + expr.use_count += 1 + return expr + cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr + var = self.cache.get(cache_key, None) + if not var: + var = self.newvar(bounds) + self.cache[cache_key] = var + if write: + if V.kernel.current_node: + V.kernel.current_node.codegen_originating_info( + buffer, only_once=True + ) + if isinstance(expr, IndentedBuffer): + if assignment: + buffer.writeline(f"{self.prefix}{var} =") + buffer.splice(expr) + buffer.writeline(self.suffix) + else: + if assignment: + line = f"{self.prefix}{var} = {expr}{self.suffix}" + else: + line = f"{expr}{self.suffix}" + buffer.writeline(line) + else: + var.bounds = var.bounds.tighten(bounds) + var.use_count += 1 + + return var + + def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable: + var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" + var = V.kernel.create_cse_var(var_name, bounds) + self.varname_map[var_name] = var + return var + + +class CodeGen: + def __init__(self) -> None: + super().__init__() + self.exit_stack = contextlib.ExitStack() + + def __enter__(self): + self.exit_stack.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + +class ScopedDict: + def __init__(self, original_dict): + self.original_dict = original_dict + self.new_items = {} + + def __getitem__(self, key): + if key in self.new_items: + return self.new_items[key] + return self.original_dict[key] + + def __setitem__(self, key, value): + self.new_items[key] = value + + def 
__contains__(self, key): + return key in self.new_items or key in self.original_dict + + def get(self, key, default=None): + if key in self.new_items: + return self.new_items[key] + return self.original_dict.get(key, default) + + +class Kernel(CodeGen): + newvar_prefix = "" + suffix = "" + overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None + # TODO: these look dead, but with all the getattr it's hard to tell... + load_format: None = None + store_format: None = None + + def __init__(self, args=None, increase_kernel_count=True): + super().__init__() + if increase_kernel_count: + metrics.generated_kernel_count += 1 + self.args = args or KernelArgs() + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + + self.num_load = 0 + self.num_reduction = 0 + + self.cse: CSE = CSE(self.newvar_prefix, self.suffix) + self.must_keep_buffers = OrderedSet() # type: ignore[var-annotated] + self.store_buffer_names = OrderedSet() # type: ignore[var-annotated] + self._load_mask = None + self._load_other = None + # OrderedSet in set_current_node + self.current_node = None + self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None + + self.removed_buffers = OrderedSet() # type: ignore[var-annotated] + self.inplaced_to_remove = OrderedSet() # type: ignore[var-annotated] + + # key: the buffer to write + # value: the buffer to read and whose memory can be reused for + # the buffer specified by key + self.inplace_update_buffers = {} + # Set minimum number of elements processed per thread. + self.min_elem_per_thread = 1 + self.kernel_name = None + + @contextlib.contextmanager + def set_current_node(self, node): + prior = self.current_node + self.current_node = node + self.node_to_bounds = node._body.bounds().get_bounds() + try: + yield + finally: + self.current_node = prior + + @contextlib.contextmanager + def swap_buffers(self, lb, cb=None, sb=None): + def scope_cse(cse): + new_cse = cse.clone() + new_cse.cache = ScopedDict(cse.cache) + new_cse.reduction_cache = ScopedDict(cse.reduction_cache) + new_cse.store_cache = ScopedDict(cse.store_cache) + return new_cse + + if cb is None: + cb = lb + loads = self.loads + compute = self.compute + stores = self.stores + cse = self.cse + self.loads = lb + self.compute = cb + self.stores = sb + self.cse = scope_cse(cse) + try: + yield + finally: + self.loads = loads + self.compute = compute + self.stores = stores + self.cse = cse + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: + raise NotImplementedError + + def indirect_load(self, name: str, index: sympy.Expr): + """A load the depends on an index we have read""" + prior = self.loads + try: + # put the load in the compute section as it might have deps + self.loads = self.compute + return self.load(name, index) + finally: + self.loads = prior + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): + raise NotImplementedError + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + raise NotImplementedError + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + raise NotImplementedError + + def scan( + self, + dtypes: Tuple[torch.dtype, ...], + combine_fn: Callable[ + [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...] 
+ ], + values: Tuple[CSEVariable, ...], + ) -> Tuple[CSEVariable, ...]: + raise NotImplementedError + + def sort( + self, + dtypes: Tuple[torch.dtype, ...], + values: Tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> Tuple[CSEVariable, ...]: + raise NotImplementedError + + def var_ranges(self): + raise NotImplementedError + + def bucketize( + self, + values: CSEVariable, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ) -> CSEVariable: + """ + See [Note: Inductor bucketize op] + """ + raise NotImplementedError + + @property + def assert_function(self) -> str: + raise NotImplementedError + + def indirect_assert( + self, + var: Union[CSEVariable, str], + lower: Optional[str], + upper: Optional[str], + mask: Optional[Union[CSEVariable, str]] = None, + ) -> str: + if isinstance(var, CSEVariable): + var = str(var) + assert isinstance(var, str) + assert lower is None or isinstance(lower, str) + assert upper is None or isinstance(upper, str) + if lower and upper: + # The conditions need to be in parens because of Python's operator precedence. + # It'd be less error-prone to use and/or/not, which is suported by triton + cond = f"({lower} <= {var}) & ({var} < {upper})" + cond_print = f"{lower} <= {var} < {upper}" + elif lower: + cond = f"{lower} <= {var}" + cond_print = cond + else: + assert upper + cond = f"{var} < {upper}" + cond_print = cond + + if mask: + cond = f"({cond}) | ~({mask})" + + return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ): + raise NotImplementedError + + def index_to_str(self, index: sympy.Expr) -> str: + raise NotImplementedError + + def __enter__(self): + # TODO: hoist this to top level + class CSEProxy: + self.name = "CSEProxy" + vr_analysis = ValueRangeAnalysis() + + @staticmethod + def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc] + def inner(*args, **kwargs): + bounds = CSEProxy._bound_variable(name, *args, **kwargs) + + value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type] + + def do_cse(v): + csevar = V.kernel.cse.generate( + V.kernel.compute, v, bounds=bounds + ) + csevar.update_on_args(name, args, kwargs) + return csevar + + return pytree.tree_map(do_cse, value) + + return inner + + @staticmethod + def _bound_variable(name, *args, **kwargs): + """ + If the variable comes from an FX node, we forward the bound we have already computed + Else, if the variable when codegen'ing another op, we try to compute its bounds + """ + from ..select_algorithm import TritonTemplateKernel + + if isinstance(V.kernel, TritonTemplateKernel): + return ValueRanges.unknown() + + fx_node = V.interpreter.current_node + if fx_node.target == name and self.node_to_bounds is not None: + assert isinstance(self.node_to_bounds, dict) + return self.node_to_bounds.get(fx_node, ValueRanges.unknown()) + elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name): + # These create lots of inner strings. We would need to compute the bounds at the ops + # We will also likely not get much from computing VRs on these nodes + if any( + s in fx_node.target + for s in ("set_indirect", "reduction", "scan") + ): + return ValueRanges.unknown() + + # We assume that the inputs come from `ops.` and are not strings. If you want to generate + # intermediary strings, wrap them in CSE variables with properly initialised bounds. 
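+                    # Illustrative example (bounds hypothetical): for ops.add(x, y) with
+                    # x.bounds == ValueRanges(0, 10) and y.bounds == ValueRanges(1, 2),
+                    # ValueRangeAnalysis.add would yield ValueRanges(1, 12), which then
+                    # becomes the bound of the resulting CSE variable.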
+ + # If there is no FX bound but we know how to compute one we do so + assert not kwargs + + def arg_to_bound(x): + if isinstance(x, CSEVariable): + return x.bounds + elif isinstance(x, sympy.Expr): + return bound_sympy(x) + else: + return x + + arg_bounds = list(map(arg_to_bound, args)) + return getattr(CSEProxy.vr_analysis, name)(*arg_bounds) + else: + return ValueRanges.unknown() + + @staticmethod + def indirect_indexing( + var: CSEVariable, + size: Union[sympy.Expr, int], + check: bool = True, + wrap_neg=True, + ): + if isinstance(size, int): + size = sympy.Integer(size) + assert isinstance(size, sympy.Expr), size + # Skip CSE since this doesn't return an expression + + if var.bounds.lower < 0: # type: ignore[operator] + if wrap_neg: + stm = ops.add(var, ops.index_expr(size, torch.long)) + # Mixed negative and non-negative + if var.bounds.upper >= 0: # type: ignore[operator] + lt = ops.lt(var, 0) + stm = ops.where(lt, stm, var) + else: + stm = var + + # Propagate bounds as we know how to compute them properly + new_bounds = ValueRanges.unknown() + if var.bounds != ValueRanges.unknown() and isinstance( + size, sympy.Number + ): + # Take the negative part of the bound and add size to it + # Then take union of that and the positive part + # This is a tighter bound than that of a generic ops.where, as we have info on the cond + neg_bounds = var.bounds & ValueRanges(-int_oo, -1) + new_bounds = ValueRanges( + neg_bounds.lower + size, neg_bounds.upper + size + ) + # We don't have a good way of representing the empty range + if var.bounds.upper >= 0: # type: ignore[operator] + pos = var.bounds & ValueRanges(0, int_oo) + new_bounds = new_bounds | pos + + var = self.cse.generate(self.compute, stm, bounds=new_bounds) + + sympy_var = parent_handler.indirect_indexing(var, size, check) + if generate_assert(check): + assert_lower = not (var.bounds.lower >= 0) + # value ranges cannot x < s when x and s are symbols + assert_upper = not isinstance(size, sympy.Number) or not ( + var.bounds.upper < size + ) + self.check_bounds(sympy_var, size, assert_lower, assert_upper) + return sympy_var + + @staticmethod + def check_bounds( + expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ): + return self.check_bounds(expr, size, lower, upper) + + @staticmethod + def load(name: str, index: sympy.Expr) -> CSEVariable: + if name in self.cse.invalidated_stores: + # A load from an invalidated store requires us to + # keep the actual buffer around + V.kernel.must_keep_buffers.add(name) + if free_symbol_is_type(index, SymT.TMP): + return self.indirect_load(name, index) + store_cache = self.cse.store_cache + if name in store_cache: + return store_cache[name] + out = self.load(name, index) + # count load that is not in the store_cache, and also not in the + # cse cache. 
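+                # A fresh load starts with use_count == 1; if cse.generate returned an
+                # already-cached variable instead, its use_count was bumped above 1 and
+                # the load is not counted again here.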
+ if out.use_count == 1: + self.num_load += 1 + return out + + @staticmethod + def _update_store_cache(name: str, value: CSEVariable): + self.cse.store_cache[name] = value + if self.current_node and name in V.graph.name_to_buffer: + buf = self.current_node.get_output(name) + for other_name in buf.get_mutations(): + self.cse.store_cache[other_name] = value + + @staticmethod + def store( + name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + self.store_buffer_names.add(name) + if mode is None: + CSEProxy._update_store_cache(name, value) + if name not in V.graph.removed_buffers: + return self.store(name, index, value, mode=mode) + else: + return None # type: ignore[return-value] + + @staticmethod + def store_reduction(name: str, index: sympy.Expr, value: CSEVariable): + self.store_buffer_names.add(name) + CSEProxy._update_store_cache(name, value) + + if name not in V.graph.removed_buffers: + return self.store_reduction(name, index, value) + + @staticmethod + def reduction( + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, Tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]: + self.num_reduction += 1 + return self.reduction(dtype, src_dtype, reduction_type, value) + + @staticmethod + def scan( + dtypes: Tuple[torch.dtype, ...], + combine_fn: Callable[ + [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], + Tuple[CSEVariable, ...], + ], + values: Tuple[CSEVariable, ...], + ) -> Tuple[CSEVariable, ...]: + return self.scan(dtypes, combine_fn, values) + + @staticmethod + def sort( + dtypes: Tuple[torch.dtype, ...], + values: Tuple[CSEVariable, ...], + stable: bool, + descending: bool, + ) -> Tuple[CSEVariable, ...]: + return self.sort(dtypes, values, stable, descending) + + @staticmethod + def bucketize( + values: CSEVariable, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ) -> CSEVariable: + """ + [Note: Inductor bucketize op] + + Given values (tensor) and offsets_name (reference to the name of a 1D + tensor), calculate the bucket that each value belongs to. + + e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True + return = [ 0, 1, 1, 1, 1, 3, 3, 4]. + + When right == False, bucket i refers to range (offsets[i], offsets[i+1]]. + When right == True, bucket i refers to range [offsets[i], offsets[i+1]). + + Offsets must be non-decreasing or the result is undefined. + """ + return self.bucketize( + values, offsets_name, offsets_size, indexing_dtype, right + ) + + # Use mypy to check protocol implemented correctly + def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]: + return h + + super().__enter__() + assert self.overrides + parent_handler = self.overrides(V.get_ops_handler()) + self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Note that V.graph.scheduler can be None when codegening triton template + kernels. 
+ """ + if V.graph.scheduler: + V.graph.scheduler.remove_kernel_local_buffers() + super().__exit__(exc_type, exc_val, exc_tb) + + def rename_indexing(self, index) -> sympy.Expr: + # adds the necessary kernel args for index expressions + # and renames variables in index expressions to kernel arg names + if isinstance(index, (list, tuple)): + return [self.rename_indexing(x) for x in index] # type: ignore[return-value] + index = V.graph.sizevars.simplify(index) + sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) + replacements = { + x: self.args.size(x) + for x in sorted_symbols + if symbol_is_type( + x, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + ), + ) + } + return sympy_subs(index, replacements) + + def create_cse_var(self, *args, **kwargs): + return CSEVariable(*args, **kwargs) + + +@dataclasses.dataclass +class OptimizationContext: + key: ClassVar[str] = "opt_ctx" + + dtype: Optional[torch.dtype] = None + ops_name: str = "" + + +@functools.lru_cache(None) +def jinja2_env(): + try: + import jinja2 + + return jinja2.Environment( + undefined=jinja2.StrictUndefined, + ) + except ImportError: + return None + + +class KernelTemplate: + """ + Base class for defining kernel templates. + + Children classes: TritonTemplate, CUDATemplate + """ + + @staticmethod + def indent_except_first(source: str, num_indents: int, indents_spacing=4): + lines = source.splitlines(True) + if len(lines) > 1: + lines[1:] = [ + (" " * indents_spacing * num_indents) + line for line in lines[1:] + ] + return "".join(lines) + + @staticmethod + def _template_from_string(source): + env = jinja2_env() + if env is not None: + env.filters["indent_except_first"] = KernelTemplate.indent_except_first + from jinja2 import TemplateSyntaxError + + class DetailedTemplateSyntaxError(TemplateSyntaxError): + def __init__(self, original_error): + super().__init__( + original_error.message, + original_error.lineno, + original_error.name, + original_error.filename, + ) + self.original_error = original_error + + def __str__(self): + error_info = f"Error in template at line {self.lineno}\n" + error_info += f"Error message: {self.message}\n" + if hasattr(self.original_error, "source"): + lines = self.original_error.source.split("\n") + error_info += "Context:\n" + start = max(0, self.lineno - 2) + end = min(len(lines), self.lineno + 2) + for i in range(start, end): + if i == self.lineno - 1: + error_info += f"{i+1}: --> {lines[i]}\n" + if hasattr(self.original_error, "column"): + error_info += ( + " " + + " " * (self.original_error.column - 1) + + "^\n" + ) + else: + error_info += f"{i+1}: {lines[i]}\n" + return error_info + + try: + return env.from_string(source) + except TemplateSyntaxError as e: + raise DetailedTemplateSyntaxError(e) from e + + return None + + @staticmethod + def _fake_get_dtype(fake_out): + _get_dtype_real = V.graph.get_dtype + + def get_dtype(name): + if name == fake_out.get_name(): + return fake_out.get_dtype() + return _get_dtype_real(name) + + return get_dtype + + def __init__(self, name: str): + self.name = name + + def maybe_append_choice(self, choices, **kwargs): + """ + Maybe generates a new ChoiceCaller and appends it into existing choices. + + choices: A list of ChoiceCallers. + kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller. 
+ """ + + try: + choices.append(self.generate(**kwargs)) + except NotImplementedError as e: + pass + + def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller": + """ + Generates a ChoiceCaller instance from the given arguments. + """ + + raise NotImplementedError diff --git a/.venv/Lib/site-packages/torch/_inductor/codegen/cpp.py b/.venv/Lib/site-packages/torch/_inductor/codegen/cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..3fecf366f65cffddf369d9e3264a2a246331cc12 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/codegen/cpp.py @@ -0,0 +1,4978 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import functools +import itertools +import math +import re +import sys +import warnings +from copy import copy, deepcopy +from enum import Enum +from typing import cast, Dict, List, Optional, Sequence, Set, Tuple, Union + +import sympy + +import torch +import torch.fx +from torch._inductor import dependencies +from torch._prims_common import is_float_dtype, is_integer_dtype +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT + +from ..._dynamo.utils import counters +from .. import codecache, config, cpp_builder, cpu_vec_isa, ir, metrics +from ..loop_body import LoopBody +from ..scheduler import ( + BaseSchedulerNode, + BaseScheduling, + ForeachKernelSchedulerNode, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) +from ..utils import ( + cache_on_self, + get_bounds_index_expr, + get_fused_kernel_name, + has_free_symbols, + is_welford_reduction, + parallel_num_threads, + Placeholder, + sympy_index_symbol, + sympy_index_symbol_with_prefix, + sympy_product, + sympy_subs, +) +from ..virtualized import NullKernelHandler, ops, OpsValue, V +from .common import ( + BackendFeature, + BracesBuffer, + CppWrapperKernelArgs, + CSE, + CSEVariable, + DataTypePropagation, + DeferredLine, + DTYPE_TO_COMPUTATION_DTYPE, + IndentedBuffer, + Kernel, + KernelArgs, + OpOverrides, + OptimizationContext, +) +from .cpp_utils import ( + _get_dtype_from_loopbodies, + _get_loop_body, + cexpr, + cexpr_index, + codegen_rand, + CppCSEVariable, + DTYPE_TO_CPP, + INDEX_TYPE, + LocalBufferContext, + promote_args, + unify_mask_base_type, + value_to_cpp, +) + + +_IS_WINDOWS = sys.platform == "win32" + + +def get_export_declaration(): + return "__declspec(dllexport)" if _IS_WINDOWS else "" + + +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") + +NATIVE_OMP_RTYPES = {"+", "*", "^", "||", "min", "max"} +RTYPE_TO_CPP = { + "sum": "+", + "prod": "*", + "xor_sum": "^", + "min": "min", + "max": "max", + "argmin": "argmin", + "argmax": "argmax", + "any": "||", + "welford_reduce": "welford", + "welford_combine": "welford", +} +VECTORIZABLE_RTYPES = { + "max", + "min", + "sum", + "prod", + "xor_sum", + "welford_reduce", + "welford_combine", + "argmin", + "argmax", + "any", +} + +PYTHON_TO_CPP = { + "Tensor": "at::Tensor", + "int": "long", + "float": "double", + "bool": "bool", + "str": "std::string", + "ScalarType": "c10::ScalarType", + "MemoryFormat": "at::MemoryFormat", + "Layout": "at::Layout", + "Device": "at::Device", + "number": "at::Scalar", +} + +CONTAINER_PYTHON_TO_CPP = { + "List": "std::vector", + "Optional": "std::optional", +} + +DTYPE_LOWP_FP = [ + torch.bfloat16, + torch.float16, +] + +VECTORIZABLE_DTYPES: List[torch.dtype] = [ + torch.float64, + torch.float, + torch.bfloat16, + torch.float16, + torch.bool, + torch.uint8, + 
torch.int8, + torch.int32, + torch.int64, +] + +MASKED_VECTORIZABLE_DTYPES: List[torch.dtype] = [ + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, +] + + +def reduction_init(reduction_type, dtype): + if dtype in DTYPE_LOWP_FP: + # Since load promotes all half-precision inputs to float, the initial + # constant for reduction must be promoted as well + dtype = torch.float32 + if reduction_type in ("xor_sum", "sum", "any"): + return 0 + if reduction_type == "prod": + return 1 + if reduction_type in ("max", "argmax", "min", "argmin"): + cdtype = DTYPE_TO_CPP[dtype] + min_var = ( + f"-std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::min()" + ) + max_var = ( + f"std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::max()" + ) + init_var = min_var if reduction_type in ("max", "argmax") else max_var + return ( + init_var + if reduction_type in ("max", "min") + else f"IndexValue<{cdtype}>{{0, {init_var}}}" + ) + if is_welford_reduction(reduction_type): + return f"Welford<{DTYPE_TO_CPP[dtype]}>()" + raise AssertionError(reduction_type) + + +def reduction_acc_type(reduction_type, dtype): + scalar_type = DTYPE_TO_CPP[DTYPE_TO_COMPUTATION_DTYPE[dtype]] + if is_welford_reduction(reduction_type): + return f"Welford<{scalar_type}>" + if reduction_type in {"argmin", "argmax"}: + return f"IndexValue<{scalar_type}>" + return scalar_type + + +def reduction_combine( + reduction_type, + var, + next_value, + index: Optional[sympy.Symbol] = None, + src_dtype=None, +): + is_bool = src_dtype == torch.bool + if reduction_type == "sum": + conjunction = "|" if is_bool else "+" + return f"{var} {conjunction} {next_value}" + if reduction_type == "prod": + return f"{var} * {next_value}" + if reduction_type == "xor_sum": + return f"{var} ^ {next_value}" + if reduction_type == "any": + return f"{var} || {next_value}" + if reduction_type in ("min", "max"): + return f"{reduction_type}_propagate_nan({var}, {next_value})" + if reduction_type == "welford_reduce": + return f"welford_combine({var}, {next_value})" + if reduction_type == "welford_combine": + if isinstance(next_value, tuple): + mean, m2, weight = next_value + else: + mean, m2, weight = reduction_project(reduction_type, next_value) + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" + if reduction_type in ("argmin", "argmax"): + if index is not None: + return f"{reduction_type}_combine({var}, {next_value}, {index})" + else: + return f"{reduction_type}_combine({var}, {next_value})" + raise AssertionError(reduction_type) + + +def reduction_project(reduction_type, acc): + if is_welford_reduction(reduction_type): + return f"{acc}.mean", f"{acc}.m2", f"{acc}.weight" + elif reduction_type in {"argmin", "argmax"}: + return f"{acc}.index" + return acc + + +@functools.lru_cache +def stride_at(index: sympy.Expr, var: sympy.Symbol): + if not index.has(var): + # see test_torchinductor_dynamic_shapes.py::test_full_boolean_dynamic_shapes_cpu + # which has tmp0 = ops.index_expr(s0 >= 1024, torch.bool) and fails below calculation. + # in this case, there is no dependencies between index and var. 
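+        # Illustrative example (symbols hypothetical): stride_at(2*x0 + x1, x0) == 2,
+        # since substituting x0 -> x0 + 1 increases the index by 2.  Here `index` does
+        # not contain `var` at all, so the stride is 0.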
+ return sympy.Integer(0) + replacement = {var: var + 1} + new_index = sympy_subs(index, replacement) # type: ignore[arg-type] + return sympy.simplify(new_index - index) + + +@functools.lru_cache +def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int): + """ + Simplifies the index expression within the range of a vectorized loop. + Given a vectorized loop variable `var` in the range of a loop with `vec_length`, + this function transforms the `index` into an equivalent form. It handles + simplifications for cases where `var` can be expressed as `vec_length * a + b`, + where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences + of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations. + + NOTE: + The simplified index expression is intended for analysis purposes only, not + for code generation. It replaces `FloorDiv` and `ModularIndexing` with free variables + which are not dependent on the loop variable `var` in the vectorized range. Check + https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217 for more details. + + Examples: + 1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then + `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable + when `div` is divisible by 16. + 2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free + variable when `mod` is divisible by 16. + """ + + div_freevar_id = 0 + mod_freevar_id = 0 + + def visit_indexing_div(divisor): + nonlocal div_freevar_id + result = FloorDiv(var, divisor) + if sympy.gcd(divisor, vec_length) == vec_length: + result = sympy.Symbol(f"{var}_div_c{div_freevar_id}") + div_freevar_id += 1 + return result + + def visit_modular_indexing(divisor, modulus): + nonlocal mod_freevar_id + result = ModularIndexing(var, divisor, modulus) + if sympy.gcd(divisor, vec_length) == vec_length: + result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") + mod_freevar_id += 1 + elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length: + result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}") + mod_freevar_id += 1 + return result + + original_index = index + + div = sympy.Wild("divisor", integer=True) + if index.has(FloorDiv): + index = index.replace(FloorDiv(var, div), visit_indexing_div) + + mod = sympy.Wild("modulus", integer=True) + if index.has(ModularIndexing): + index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) + + index = sympy.simplify(index) + if index != original_index: + return simplify_index_in_vec_range(index, var, vec_length) + + return index + + +@functools.lru_cache +def stride_at_vec_range( + index: sympy.Expr, var: sympy.Symbol, vec_length: Optional[int] = None +): + if vec_length: + index = simplify_index_in_vec_range(index, var, vec_length) + return stride_at(index, var) + + +class OuterLoopFusedSchedulerNode(FusedSchedulerNode): + @classmethod + def fuse( # type: ignore[override] + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode, outer_loop_fusion_depth + ): + assert node1.scheduler is node2.scheduler + assert all( + type(node) + in ( + OuterLoopFusedSchedulerNode, + SchedulerNode, + FusedSchedulerNode, + ) + for node in (node1, node2) + ) + if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)): + return cls( + node1.scheduler, + ( + list(node1.get_outer_nodes()) + if type(node1) is OuterLoopFusedSchedulerNode + else [ + node1, + ] + ) + + ( + list(node2.get_outer_nodes()) + if type(node2) is 
OuterLoopFusedSchedulerNode + else [ + node2, + ] + ), + outer_loop_fusion_depth, + ) + else: + return cls(node1.scheduler, [node1, node2], outer_loop_fusion_depth) # type: ignore[list-item] + + def __init__( + self, + scheduler: "Scheduler", + outer_fused_nodes: List[Union[FusedSchedulerNode, SchedulerNode]], + outer_loop_fusion_depth, + ): + self.outer_fused_nodes: List[ + Union[FusedSchedulerNode, SchedulerNode] + ] = outer_fused_nodes + self.outer_loop_fusion_depth = outer_loop_fusion_depth + flatten_snodes = [] + for _node in self.outer_fused_nodes: + assert isinstance(_node, (SchedulerNode, FusedSchedulerNode)) + flatten_snodes.extend(list(_node.get_nodes())) + super().__init__(scheduler, flatten_snodes) # type: ignore[arg-type] + + def get_outer_nodes(self): + return self.outer_fused_nodes + + def check_outer_fusion_loop_level_attr( + self, cpp_kernel_proxy_list, outer_loop_fusion_depth + ): + # This function ensures that the same tiling split is applied at each loop level within the outer loop fusion depth. + # In the fusion stage, we only examine nodes with same vars and reduce. + # However, for nodes with same vars and reduce, the loops may still have different tile splits. + # For example (test_expr_vec_non_contiguous in test_cpu_repro.py): + # * buf0 tiling along the 2nd loop level, buf1 tiling along the 3rd loop level. + # If the check failed, we should fall back to standard loop codegen. + def _inner( + left_loop_level: LoopLevel, + right_loop_level: LoopLevel, + loop_fusion_depth: int, + ) -> bool: + # Check if same loop level attr + outer_loops_attr_compare_list = [ + "var", + "size", + "offset", + "steps", + ] + if not ( + all( + getattr(left_loop_level, attr_compare) + == getattr(right_loop_level, attr_compare) + for attr_compare in outer_loops_attr_compare_list + ) + ): + return False + + assert loop_fusion_depth >= 1 + if (loop_fusion_depth := loop_fusion_depth - 1) > 0: + # If the next loop level is expected to undergo outer loop fusion, + # there should be no kernel present at the current loop level. 
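+                # Illustrative example (loop shape hypothetical): fusing two proxies that
+                # both look like `for x0: for x1: <kernel>` at outer_loop_fusion_depth=2
+                # requires the x0 levels to agree on var/size/offset/steps and to carry no
+                # kernel themselves; only the innermost (x1) level may hold a kernel.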
+ assert ( + left_loop_level.kernel is None and right_loop_level.kernel is None + ) + # Check next loop level attr + if any( + # Assume no main/tail loop split at any outer loop fusion depth + # Given no clear performance benefit for this complex case + len(loop_level.inner) != 1 + for loop_level in [left_loop_level, right_loop_level] + ) or not _inner( + left_loop_level.inner[0], + right_loop_level.inner[0], + loop_fusion_depth, + ): + return False + + return True + + for idx in range(len(cpp_kernel_proxy_list) - 1): + left_loop_nest = cpp_kernel_proxy_list[idx].loop_nest + right_loop_nest = cpp_kernel_proxy_list[idx + 1].loop_nest + if any( + # Assume no main/tail loop split at any outer loop fusion depth + len(loop_nest.root) != 1 + for loop_nest in [left_loop_nest, right_loop_nest] + ) or not _inner( + left_loop_nest.root[0], right_loop_nest.root[0], outer_loop_fusion_depth + ): + return False + + return True + + def merge_outer_fusion_kernels( + self, + cpp_kernel_proxy_list, + ): + loop_nest_list: List[LoopNestWithSplit] = [ + kernel.loop_nest for kernel in cpp_kernel_proxy_list + ] + kernel_group = cpp_kernel_proxy_list[0].kernel_group + + def _merge_outer_fusion_loop_levels( + loop_level_nested_list: List[List["LoopLevel"]], + outer_loop_fusion_depth, + ): + assert outer_loop_fusion_depth >= 1 + # Assume no main/tail loop split at any outer loop fusion depth + assert all( + len(loop_level_list) == 1 for loop_level_list in loop_level_nested_list + ) + if (outer_loop_fusion_depth := outer_loop_fusion_depth - 1) >= 1: + # Further merge the next loop level + next_loop_level_nested_list = [ + loop_level_list[0].inner + for loop_level_list in loop_level_nested_list + ] + _merge_outer_fusion_loop_levels( + next_loop_level_nested_list, + outer_loop_fusion_depth, + ) + else: + outer_loop_fused_kernel = OuterLoopFusedKernel(kernel_group) + loop_level_of_first_kernel = loop_level_nested_list[0][0] + for kernel_idx in range(len(loop_level_nested_list)): + outer_loop_fused_kernel.inner.append( + deepcopy(loop_level_nested_list[kernel_idx][0]), + ) + loop_level_of_first_kernel.inner = [] + loop_level_of_first_kernel.kernel = outer_loop_fused_kernel + + # Merge the List[LoopNestWithSplit] from cpp_kernel_proxy_list + # into cpp_kernel_proxy_list[0].loop_nest + _merge_outer_fusion_loop_levels( + [_loop_nest.root for _loop_nest in loop_nest_list], # type: ignore[misc] + self.outer_loop_fusion_depth, + ) + return cpp_kernel_proxy_list[0] + + +class RecordOptimizationContext: + def __init__(self, func_name: str = ""): + self.func_name = func_name + self.current_node: Optional[torch.fx.Node] = None + self.opt_ctx: Optional[OptimizationContext] = None + + def __enter__(self): + assert V.interpreter + assert V.interpreter.current_node + + self.current_node = V.interpreter.current_node + assert self.current_node is not None + if OptimizationContext.key in self.current_node.meta: + self.opt_ctx = self.current_node.meta[OptimizationContext.key] + else: + self.opt_ctx = OptimizationContext() + assert self.opt_ctx is not None + self.opt_ctx.ops_name = self.func_name + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self.current_node + assert self.opt_ctx + self.current_node.meta[OptimizationContext.key] = self.opt_ctx + + def get_opt_ctx(self): + return self.opt_ctx + + def get_fx_node(self): + assert self.current_node + return self.current_node + + +class CppOverrides(OpOverrides): + """Map element-wise ops to C++""" + + @staticmethod + def add(a, b): + return f"decltype({a})({a} + 
{b})" + + @staticmethod + def sub(a, b): + return f"decltype({a})({a} - {b})" + + @staticmethod + def mul(a, b): + return f"decltype({a})({a} * {b})" + + @staticmethod + def to_dtype(x, dtype, src_dtype=None, use_compute_types=True): + assert isinstance(x, CppCSEVariable) + if src_dtype is None: + src_dtype = x.dtype + expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype) + csevar = V.kernel.cse.generate(V.kernel.compute, expr) + csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype}) + if dtype in [torch.bfloat16, torch.float16] and src_dtype == torch.float: + """ + https://github.com/pytorch/pytorch/issues/115260 + For FusedSchedulerNode[node1, node2], the node2 loads what node1 stores and the buffer is + in low-precision floating point data type. When the output of node1 also serves as the output of the + kernel, the result of nodes would be different from the case when output of node1 is not the output + of the kernel (where we don't need to insert `to_dtype` for legalization). To address the problem, on + storing the lowp node1 output, we also add the inverse dtype conversion to high precision data type + to the cse cache. + + Example (pseudo code): + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + store(buf, node1_output_lowp) + node2_input_lowp = load(buf) + node2_input = to_dtype(node2_input_lowp, dtype=torch.float) + + Without cse cache trick: + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + store(buf, node1_output_lowp) + node2_input_lowp = node_output_lowp # hit store cache + node2_input = to_dtype(node2_input_lowp, dtype=torch.float) + + With cse cache trick: + node1_output = ... + node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16) + # also add `to_dtype(node1_input_lowp, dtype=torch.float)` -> `node1_output` to cse cache + store(buf, node1_output_lowp) + node2_input_lowp = node_output_lowp # hit store cache + node2_input = node1_output # hit cse cache + """ + V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype) + return csevar + + @staticmethod + def to_dtype_bitcast(x, dtype, src_dtype): + assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP" + if src_dtype in (torch.float16, torch.bfloat16): + # c10::bit_cast requires the source and target have the bitwidth. + # Because the input tensor's dtype could be promoted, e.g. from float16 to + # float, we have to cast the tensor to its original source dtype before + # invoking bit_cast. We also need to convert the bit-casted tensor + # back to float to make sure we keep using higher precision values + # for the rest of the computation. 
+ cast_x = f"c10::convert<{DTYPE_TO_CPP[src_dtype]}>({x})" + cast_x = f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({cast_x})" + return f"c10::convert<{DTYPE_TO_CPP[torch.float32]}>({cast_x})" + else: + return f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({x})" + + @staticmethod + def abs(x): + return f"std::abs({x})" + + @staticmethod + def sin(x): + return f"std::sin({x})" + + @staticmethod + def cos(x): + return f"std::cos({x})" + + @staticmethod + def neg(x): + return f"decltype({x})(-{x})" + + @staticmethod + def exp(x): + # return f"Sleef_expf_u10({x})" + return f"std::exp({x})" + + @staticmethod + def exp2(x): + return f"std::exp2({x})" + + @staticmethod + def expm1(x): + return f"std::expm1({x})" + + @staticmethod + def erf(x): + return f"std::erf({x})" + + @staticmethod + def erfc(x): + return f"std::erfc({x})" + + @staticmethod + def erfinv(x): + return f"calc_erfinv({x})" + + @staticmethod + def sqrt(x): + return f"std::sqrt({x})" + + @staticmethod + def rsqrt(x): + return f"1 / std::sqrt({x})" + + @staticmethod + def log1p(x): + bug = config.cpp.inject_log1p_bug_TESTING_ONLY + if bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"std::log1p({x})" + else: + raise AssertionError( + f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def tan(x): + return f"std::tan({x})" + + @staticmethod + def tanh(x): + return f"std::tanh({x})" + + @staticmethod + def signbit(x): + """ + On windows std::signbit only support float type. + Ref: https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/signbit?view=msvc-170 + """ + return ( + f"std::signbit(static_cast({x}))" + if _IS_WINDOWS + else f"std::signbit({x})" + ) + + @staticmethod + def pow(a, b): + return f"std::pow({a}, {b})" + + @staticmethod + def log(x): + return f"std::log({x})" + + @staticmethod + def round(x): + return f"std::nearbyint({x})" + + @staticmethod + def floor(x): + return f"std::floor({x})" + + @staticmethod + def floordiv(a, b): + # a and b are integer type + quot = f"{a} / {b}" + rem = f"{a} % {b}" + return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? 
{quot} - 1 : {quot}) : {quot})" + + @staticmethod + def ceil(x): + return f"std::ceil({x})" + + @staticmethod + def trunc(x): + return f"std::trunc({x})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + return f"{a} / {b}" + + @staticmethod + def fmod(a, b): + return f"std::fmod({a}, {b})" + + @staticmethod + def isinf(x): + return f"std::isinf({x})" + + @staticmethod + def isnan(x): + return f"std::isnan({x})" + + @staticmethod + def lgamma(x): + return f"std::lgamma({x})" + + @staticmethod + def acos(x): + return f"std::acos({x})" + + @staticmethod + def acosh(x): + return f"std::acosh({x})" + + @staticmethod + def cosh(x): + return f"std::cosh({x})" + + @staticmethod + def sinh(x): + return f"std::sinh({x})" + + @staticmethod + def asin(x): + return f"std::asin({x})" + + @staticmethod + def asinh(x): + return f"std::asinh({x})" + + @staticmethod + def atan2(x, y): + return f"std::atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"std::atan({x})" + + @staticmethod + def atanh(x): + return f"std::atanh({x})" + + @staticmethod + def copysign(x, y): + return f"std::copysign({x}, {y})" + + @staticmethod + def frexp(x): + cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" + if all(cache_key in V.kernel.cse.cache for cache_key in cache_keys): + return tuple(V.kernel.cse.cache[cache_key] for cache_key in cache_keys) + + code = BracesBuffer() + exponent = V.kernel.cse.newvar() + mantissa = V.kernel.cse.newvar() + code.writeline(f"int32_t {exponent};") + code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});") + V.kernel.compute.splice(code) + cse_vars = (mantissa, exponent) + for cache_key, cse_var in zip(cache_keys, cse_vars): + V.kernel.cse.cache[cache_key] = cse_var + return mantissa, exponent + + @staticmethod + def hypot(x, y): + return f"std::hypot({x}, {y})" + + @staticmethod + def log10(x): + return f"std::log10({x})" + + @staticmethod + def log2(x): + return f"std::log2({x})" + + @staticmethod + def nextafter(x, y): + return f"std::nextafter({x}, {y})" + + @staticmethod + def relu(x): + bug = config.cpp.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" + elif bug == "runtime_error": + return f"{x}; throw 1" + elif bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"std::max({x}, decltype({x})(0))" + else: + raise AssertionError( + f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def minimum(a, b): + return f"min_propagate_nan({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"max_propagate_nan({a}, {b})" + + @staticmethod + def where(a, b, c): + return f"{a} ? 
{b} : {c}" + + @staticmethod + def mod(a, b): + return f"mod({a}, {b})" + + @staticmethod + def constant(val, dtype): + if dtype in DTYPE_LOWP_FP: + # Since load promotes all half-precision inputs to float, constants + # must be promoted as well + dtype = torch.float32 + return value_to_cpp(val, DTYPE_TO_CPP[dtype]) + + @staticmethod + def index_expr(expr, dtype): + idx_str = cexpr(V.kernel.rename_indexing(expr)) + var = V.kernel.cse.generate( + V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr) + ) + return ops.to_dtype(var, dtype) + + @staticmethod + def masked(mask, body, other): + code = BracesBuffer() + + # Write masked operation into a lambda + body_var = V.kernel.cse.newvar() + code.writeline(f"auto {body_var} = [&]") + with V.kernel.swap_buffers(code), code.indent(): + result = body() + code.writeline(f"return {result};") + code.writeline(";") + V.kernel.compute.splice(code) + + # Use the lambda's return type as the type of other + other_code = value_to_cpp(other, f"decltype({body_var}())") + return f"{mask} ? {body_var}() : {other_code}" + + @staticmethod + def logical_and(a, b): + return f"{a} && {b}" + + @staticmethod + def logical_not(a): + return f"!{a}" + + @staticmethod + def logical_or(a, b): + return f"{a} || {b}" + + @staticmethod + def logical_xor(a, b): + return f"{a} != {b}" + + @staticmethod + def bitwise_and(a, b): + return f"decltype({a})({a} & {b})" + + @staticmethod + def bitwise_not(a): + return f"decltype({a})(~{a})" + + @staticmethod + def bitwise_or(a, b): + return f"decltype({a})({a} | {b})" + + @staticmethod + def bitwise_xor(a, b): + return f"decltype({a})({a} ^ {b})" + + @staticmethod + def bitwise_left_shift(a, b): + return f"decltype({a})({a} << {b})" + + @staticmethod + def bitwise_right_shift(a, b): + return f"decltype({a})({a} >> {b})" + + @staticmethod + def rand(seed: sympy.Expr, offset: sympy.Expr): + return f"normalized_rand_cpu({seed}, {offset})" + + @staticmethod + def randn(seed: sympy.Expr, offset: sympy.Expr): + return f"randn_cpu({seed}, {offset})" + + @staticmethod + def randint64(seed: sympy.Expr, offset: sympy.Expr, low, high): + return f"randint64_cpu({seed}, {offset}, {low}, {high})" + + @staticmethod + def sigmoid(x): + return f"decltype({x})(1) / (decltype({x})(1) + std::exp(-{x}))" + + @staticmethod + def sign(x): + code = BracesBuffer() + scalar_zero = f"decltype({x})(0)" + scalar_one = f"decltype({x})(1)" + code.writeline("[&]()") + with code.indent(): + code.writeline(f"auto left = {x} > 0 ? {scalar_one} : {scalar_zero};") + code.writeline(f"auto right = {x} < 0 ? {scalar_one} : {scalar_zero};") + code.writeline("return left - right;") + code.writeline("()") + return code + + +CppOverrides._initialize_pointwise_overrides("cpp") + + +class CppVecOverrides(CppOverrides): + """Map element-wise ops to aten vectorization C++""" + + def __new__(cls, *args, **kargs): + self = super().__new__(cls) + + def wrap(func): + # `CppVecKernel` generates both scalar ops and vector ops according to + # whether the inputs are scalars or vectors while all ops in `CppVecOverrides` + # (except for some ops explained below) assume the inputs are vectors. We wrap the ops in + # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to + # `CppOverrides` when all inputs are scalars. + # + # Notes on ops handled separately in their own functions: + # `ops.masked`: + # needs recursive handling of masked body. + # `ops.index_expr`: + # needs to further analyze the dependency of the index expression on + # the tiling itervar. 
+ def wrapper(*args, **kwargs): + scalars = [ + arg + for arg in args + if isinstance(arg, (int, sympy.Expr)) + or (isinstance(arg, CppCSEVariable) and not arg.is_vec) + ] + vectors = [ + arg + for arg in args + if isinstance(arg, CppCSEVariable) and arg.is_vec + ] + new_args = list(args) + if scalars and vectors: + new_args = [] + for arg in args: + if isinstance(arg, (int, sympy.Expr)): + if isinstance(arg, sympy.Expr) and not arg.is_number: + arg = ops.index_expr(arg, torch.int64) + else: + arg = ops.constant(arg, torch.int64) + arg = arg.value if isinstance(arg, OpsValue) else arg + new_args.append(arg) + + # DType Promotion + if vectors: + # We have saw several data type mismatch issues related with index_expr in + # the lowering phase of torch.int8. torch.int32, torch.int64. + # 1. int32 and int64 in test_torchinductor.py::test_max_pool2d_with_indices_backward3_cpu + # 2. int8 and int32 in test_torchinductor.py::test_max_pool2d5_cpu + # 3. int32 and fp32 in test_torchinductor_dynamic_shapes.py::test_avg_pool2d8_dynamic_shapes_cpu + if len(new_args) == 2: + new_args = promote_args(new_args) + elif func == CppVecOverrides.where: + new_args[1:] = promote_args(new_args[1:]) + + # Broadcast scalar args to vector + if scalars and vectors: + assert isinstance(V.kernel, CppVecKernel) + new_args = [ + V.kernel.broadcast(new_arg) + if ( + isinstance(new_arg, CppCSEVariable) + and not new_arg.is_vec + and func + not in [ + CppVecOverrides.rand, + CppVecOverrides.randn, + CppVecOverrides.randint64, + ] + ) + else new_arg + for new_arg in new_args + ] + + if vectors: + return func(*new_args, **kwargs) + else: + # fallback to scalar ops + scalar_ops = super(CppVecOverrides, self) + scalar_func = getattr( + scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__) # type: ignore[attr-defined] + ) + assert scalar_func is not None + return scalar_func(*args, **kwargs) + + return wrapper + + for name, method in vars(CppVecOverrides).items(): + if getattr(method, "__class__", None) == staticmethod and name not in [ + "masked", + "index_expr", + ]: + setattr(self, name, wrap(method.__func__)) + + return self + + @staticmethod + def add(a, b): + return f"{a} + {b}" + + @staticmethod + def sub(a, b): + return f"{a} - {b}" + + @staticmethod + def mul(a, b): + return f"{a} * {b}" + + @staticmethod + def truediv(a, b): + return f"{a} / {b}" + + @staticmethod + def abs(x): + return f"{x}.abs()" + + @staticmethod + def sin(x): + return f"{x}.sin()" + + @staticmethod + def cos(x): + return f"{x}.cos()" + + @staticmethod + def exp(x): + return f"{x}.exp()" + + @staticmethod + def exp2(x): + return f"{x}.exp2()" + + @staticmethod + def expm1(x): + # decompose for a better performance + vec_one = f"decltype({x})(1)" + return f"{x}.exp() - {vec_one}" + + @staticmethod + def erf(x): + return f"{x}.erf()" + + @staticmethod + def erfc(x): + return f"{x}.erfc()" + + @staticmethod + def erfinv(x): + return f"{x}.erfinv()" + + @staticmethod + def sqrt(x): + return f"{x}.sqrt()" + + @staticmethod + def eq(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} == {y})" + + @staticmethod + def ne(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + if x.dtype == torch.bool: + assert y.dtype == torch.bool + x_cast, y_cast = unify_mask_base_type(V.kernel.compute, (x, y)) + return f"{x_cast} != {y_cast}" + else: + assert x.dtype is not None + return 
f"{V.kernel._get_mask_type(x.dtype)}({x} != {y})" + + @staticmethod + def lt(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} < {y})" + + @staticmethod + def gt(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} > {y})" + + @staticmethod + def le(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} <= {y})" + + @staticmethod + def ge(x, y): + assert isinstance(V.kernel, CppVecKernel) + assert isinstance(x, CppCSEVariable) + assert x.dtype is not None + return f"{V.kernel._get_mask_type(x.dtype)}({x} >= {y})" + + @staticmethod + def and_(x, y): + return f"{x} & {y}" + + @staticmethod + def rsqrt(x): + return f"{x}.rsqrt()" + + @staticmethod + def pow(a, b): + return f"{a}.pow({b})" + + @staticmethod + def log(x): + return f"{x}.log()" + + @staticmethod + def round(x): + return f"{x}.round()" + + @staticmethod + def floor(x): + return f"{x}.floor()" + + @staticmethod + def ceil(x): + return f"{x}.ceil()" + + @staticmethod + def trunc(x): + return f"{x}.trunc()" + + @staticmethod + def fmod(a, b): + return f"{a}.fmod({b})" + + @staticmethod + def lgamma(x): + return f"{x}.lgamma()" + + @staticmethod + def logical_and(a, b): + return f"{a} & {b}" + + @staticmethod + def logical_not(a): + return f"~{a}" + + @staticmethod + def logical_or(a, b): + return f"{a} | {b}" + + @staticmethod + def logical_xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def bitwise_and(a, b): + return f"{a} & {b}" + + @staticmethod + def bitwise_not(a): + return f"~{a}" + + @staticmethod + def bitwise_or(a, b): + return f"{a} | {b}" + + @staticmethod + def bitwise_xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def bitwise_left_shift(a, b): + return f"{a} << {b}" + + @staticmethod + def bitwise_right_shift(a, b): + return f"{a} >> {b}" + + @staticmethod + def load_seed(name, offset): + assert isinstance(V.kernel, CppVecKernel) + return f"{V.kernel.load(name, offset)}" + + @staticmethod + def rand(seed, offset): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = ( + f"result[offset_idx] = normalized_rand_cpu({seed}, offset[offset_idx]);" + ) + return codegen_rand(offset, code, rand_function) + + @staticmethod + def randn(seed, offset): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = f"result[offset_idx] = randn_cpu({seed}, offset[offset_idx]);" + return codegen_rand(offset, code, rand_function) + + @staticmethod + def randint64(seed, offset, low, high): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + rand_function = f"result[offset_idx] = randint64_cpu({seed}, offset[offset_idx], {low}, {high});" + return codegen_rand(offset, code, rand_function, torch.int64) + + @staticmethod + def remainder(a, b): + assert ( + a.dtype == b.dtype + ), "remainder vec implementation expect the same inputs' dtype." 
+ return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}" + + @staticmethod + def tan(a): + return f"{a}.tan()" + + @staticmethod + def tanh(a): + vec_one = f"decltype({a})(1)" + vec_two = f"decltype({a})(2)" + vec_minus_two = f"decltype({a})(-2)" + return f"{vec_two} / ({vec_one} + ({vec_minus_two} * {a}).exp()) - {vec_one}" + + @staticmethod + def reciprocal(a): + return f"{a}.reciprocal()" + + @staticmethod + def atan(x): + return f"{x}.atan()" + + @staticmethod + def acos(x): + return f"{x}.acos()" + + @staticmethod + def asin(x): + return f"{x}.asin()" + + @staticmethod + def cosh(x): + return f"{x}.cosh()" + + @staticmethod + def sinh(x): + return f"{x}.sinh()" + + @staticmethod + def log10(x): + return f"{x}.log10()" + + @staticmethod + def log2(x): + return f"{x}.log2()" + + @staticmethod + def nextafter(x, y): + return f"{x}.nextafter({y})" + + @staticmethod + def copysign(a, b): + return f"{a}.copysign({b})" + + @staticmethod + def atan2(a, b): + return f"{a}.atan2({b})" + + @staticmethod + def hypot(a, b): + return f"{a}.hypot({b})" + + @staticmethod + def atanh(x): + # For real x, atanh(x) = 1/2 * log((1+x)/(1-x)) + vec_one = f"decltype({x})(1)" + vec_one_half = f"decltype({x})(0.5)" + return f"{vec_one_half} * (({vec_one} + {x})/({vec_one} - {x})).log()" + + @staticmethod + def asinh(x): + # For real x, asinh(x) = log(x + sqrt(1 + x**2)) + vec_one = f"decltype({x})(1)" + return f"({x} + ({vec_one} + {x}*{x}).sqrt()).log()" + + @staticmethod + def acosh(x): + return f"{x}.acosh()" + + @staticmethod + def relu(x): + bug = config.cpp.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" + elif bug == "runtime_error": + return f"{x}; throw 1" + elif bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"at::vec::clamp_min({x}, decltype({x})(0))" + else: + raise AssertionError( + f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + # TODO: this seems to be dead + @staticmethod + def sigmoid(x): + return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())" + + @staticmethod + def neg(x): + return f"{x}.neg()" + + @staticmethod + def floordiv(a, b): + if is_float_dtype(a.dtype): + assert ( + a.dtype == b.dtype + ), "div_floor_floating_vec implementation expect the same inputs' dtype." 
+ return f"div_floor_floating_vec({a}, {b})" + else: + assert all(is_integer_dtype(item.dtype) for item in [a, b]) + # a and b are integer type + _t = f"decltype({a})" + if V.kernel._get_raw_num_vectors(b.dtype) < 1: + # Doing blend to set the remaining bits of b to non-zero + b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})" + quot = f"{a} / {b}" + has_rem = f"({a} % {b} != {_t}(0))" + is_neg = f"(({a} < {_t}(0)) != ({b} < {_t}(0)))" + return f"{_t}::blendv({quot}, {quot} - {_t}(1), {has_rem} & {is_neg})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + if V.kernel._get_raw_num_vectors(b.dtype) < 1: + # Doing blend to set the remaining bits of b to non-zero + _t = f"decltype({b})" + b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})" + return f"{a} / {b}" + + @staticmethod + def minimum(a, b): + if a.dtype == torch.bool: + assert b.dtype == torch.bool + a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) + return f"{a_cast} & {b_cast}" + else: + return f"at::vec::minimum({a}, {b})" + + @staticmethod + def maximum(a, b): + if a.dtype == torch.bool: + assert b.dtype == torch.bool + a_cast, b_cast = unify_mask_base_type(V.kernel.compute, (a, b)) + return f"{a_cast} | {b_cast}" + else: + return f"at::vec::maximum({a}, {b})" + + @staticmethod + def square(a): + return f"{a} * {a}" + + @staticmethod + def where(a, b, c): + assert isinstance(V.kernel, CppVecKernel) + if b.dtype == torch.bool: + assert c.dtype == torch.bool + blendv_a, blendv_b, blendv_c = unify_mask_base_type( + V.kernel.compute, (a, b, c) + ) + return f"decltype({blendv_b})::blendv({blendv_c}, {blendv_b}, {blendv_a})" + else: + return f"decltype({b})::blendv({c}, {b}, {V.kernel._get_mask_cast(a, b.dtype)})" + + @staticmethod + def sign(x): + code = BracesBuffer() + vec_zero = f"decltype({x})(0)" + vec_one = f"decltype({x})(1)" + blendv_l = f"decltype({x})::blendv({vec_zero}, {vec_one}, {vec_zero} < {x})" + blendv_r = f"decltype({x})::blendv({vec_zero}, {vec_one}, {x} < {vec_zero})" + code.writeline("[&]()") + with code.indent(): + code.writeline(f"auto left = {blendv_l};") + code.writeline(f"auto right = {blendv_r};") + code.writeline("return left - right;") + code.writeline("()") + return code + + @staticmethod + def to_dtype(x, dtype, src_dtype=None, use_compute_dtypes=True): + assert dtype in [ + torch.bool, + torch.float64, + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, + ], f"{__name__} does not support {dtype}" + assert isinstance(x, CppCSEVariable) + src_dtype = x.dtype + expr = V.kernel.get_to_dtype_expr(x, dtype, src_dtype) + csevar = V.kernel.cse.generate(V.kernel.compute, expr) + csevar.update_on_args("to_dtype", (x, dtype), {"src_dtype": src_dtype}) + if dtype in [torch.bfloat16, torch.float16] and src_dtype == torch.float: + V.kernel.cache_dtype_convert(x, src_dtype, csevar, dtype) + return csevar + + @staticmethod + def log1p(x): + bug = config.cpp.inject_log1p_bug_TESTING_ONLY + if bug == "accuracy": + return f"{x} + decltype({x})(1)" + elif bug is None: + return f"{x}.log1p()" + else: + raise AssertionError( + f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def masked(mask, body, other): + assert isinstance(V.kernel, CppVecKernel) + code = BracesBuffer() + var = V.kernel.cse.newvar() + with V.kernel.masked(mask) as new_mask: + code.writeline(f"auto {var} = [&]") + with V.kernel.swap_buffers(code), code.indent(): + result = body() + 
code.writeline(f"return {result};") + code.writeline(";") + V.kernel.compute.splice(code) + + dtype = result.dtype + body_code = f"{var}()" + body_code_vec = ( + body_code + if result.is_vec + else f"{V.kernel._get_vec_type(dtype)}({body_code})" + ) + other_code = value_to_cpp(other, DTYPE_TO_CPP[dtype]) + # loading bool as VecMask + other_code_vec = ( + f"{V.kernel._get_mask_type()}::from({other_code})" + if dtype == torch.bool + else f"{V.kernel._get_vec_type(dtype)}({other_code})" + ) + assert isinstance(new_mask, CppCSEVariable), new_mask + if new_mask.is_vec: + code = BracesBuffer() + code.writeline("[&]") + with V.kernel.swap_buffers(code), code.indent(): + code.writeline(f"if ({new_mask}.all_zero())") + with code.indent(): + code.writeline(f"return {other_code_vec};") + code.writeline("else") + with code.indent(): + # Create cse variable to reuse kernel.overrides.where + body_vec_var = V.kernel.cse.generate( + V.kernel.compute, + body_code_vec, + ) + other_vec_var = V.kernel.cse.generate( + V.kernel.compute, + other_code_vec, + ) + assert isinstance(body_vec_var, CppCSEVariable), body_vec_var + assert isinstance(other_vec_var, CppCSEVariable), other_vec_var + body_vec_var.dtype = dtype + other_vec_var.dtype = dtype + code.writeline( + f"return {V.kernel.overrides.where(new_mask, body_vec_var, other_vec_var)};" + ) + code.writeline("()") + csevar = V.kernel.cse.generate( + V.kernel.compute, + code, + ) + elif result.is_vec: + csevar = V.kernel.cse.generate( + V.kernel.compute, f"{mask} ? {body_code_vec} : {other_code_vec}" + ) + else: + csevar = V.kernel.cse.generate( + V.kernel.compute, f"{mask} ? {body_code} : {other_code}" + ) + # `result` is explicitly added to the args for correct propagation + # of relevant itervars and vectorization status. 
+ csevar.update_on_args("masked", (mask, body, other, result), {}) + return csevar + + @staticmethod + def index_expr(expr, dtype): + assert isinstance(V.kernel, CppVecKernel) + index = V.kernel.rename_indexing(expr) + tiling_var = V.kernel.itervars[V.kernel.tiling_idx] + stride = V.kernel._try_get_const_stride(index, tiling_var) + if stride == 0: + return CppOverrides.index_expr(expr, dtype) + elif stride is not None: + idx = V.kernel.cse.generate( + V.kernel.compute, cexpr(index), bounds=get_bounds_index_expr(expr) + ) + value = ops.to_dtype(idx, dtype) + if isinstance(value, OpsValue): + value = value.value + csevar = V.kernel.arange(value, stride) + else: + csevar = V.kernel._load_or_store_non_contiguous( # type: ignore[assignment] + None, index, dtype, V.kernel.compute + ) + csevar.update_on_args("index_expr", (expr, dtype), {}) + return csevar + + @staticmethod + def frexp(x): + cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" + if all(cache_key in V.kernel.cse.cache for cache_key in cache_keys): + return tuple(V.kernel.cse.cache[cache_key] for cache_key in cache_keys) + + cdtype = DTYPE_TO_CPP[x.dtype] + size = V.kernel.tail_size if V.kernel.tail_size else V.kernel.tiling_factor + code = BracesBuffer() + exponent = V.kernel.cse.newvar() + mantissa = V.kernel.cse.newvar() + exponent.update_on_args("frexp", (x,), kwargs={}) + mantissa.update_on_args("frexp", (x,), kwargs={}) + n_vec = V.kernel._get_num_vectors(x.dtype) + mantissa_t = ( + f"at::vec::Vectorized<{cdtype}>" + if n_vec == 1 + else f"at::vec::VectorizedN<{cdtype}, {n_vec}>" + ) + code.writeline( + f"at::vec::Vectorized {exponent};" + if n_vec == 1 + else f"at::vec::VectorizedN {exponent};" + ) + code.writeline(f"{mantissa_t} {mantissa};") + code.writeline("[&]()") + with code.indent(): + code.writeline( + f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf;" + ) + code.writeline(f"{x}.store(tmpbuf.data(), {cexpr_index(size)});") + code.writeline( + f"__at_align__ std::array tmpbuf_exponent;" + ) + code.writeline( + f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf_mantissa;" + ) + code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") + with code.indent(): + code.writeline( + "tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]);" + ) + code.writeline( + f"{exponent} = at::vec::Vectorized::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" + if n_vec == 1 + else f"{exponent} = at::vec::VectorizedN::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" + ) + code.writeline( + f"{mantissa} = {mantissa_t}::loadu(tmpbuf_mantissa.data(), {cexpr_index(size)});" + ) + code.writeline("();") + V.kernel.compute.splice(code) + cse_vars = (mantissa, exponent) + for cache_key, cse_var in zip(cache_keys, cse_vars): + V.kernel.cse.cache[cache_key] = cse_var + return mantissa, exponent + + @classmethod + def scalarize(cls, scalar_func): + def inner(*args, **kwargs): + assert not kwargs + kernel = V.kernel + assert isinstance(kernel, CppVecKernel) + code = BracesBuffer() + code.writeline("[&]()") + vec_dtype = args[0].dtype + n_vec = kernel._get_num_vectors(vec_dtype) + size = kernel.tail_size if kernel.tail_size else kernel.tiling_factor + scalar_args = [] + cdtype = DTYPE_TO_CPP[vec_dtype] + output_mask = scalar_func.__name__ in ( + "isinf", + "isnan", + "signbit", + ) + octype = "bool" if output_mask else cdtype + octype = ( + DTYPE_TO_CPP[args[-2]] + if (scalar_func.__name__ == "to_dtype_bitcast") + else octype + ) + with code.indent(): + for argidx, arg in enumerate(args): + if 
isinstance(arg, CppCSEVariable): + assert arg.is_vec + assert arg.dtype == vec_dtype + code.writeline( + f"__at_align__ std::array<{cdtype}, {kernel.tiling_factor}> tmpbuf{argidx};" + ) + code.writeline( + f"{arg}.store(tmpbuf{argidx}.data(), {cexpr_index(size)});" + ) + scalar_args.append(f"tmpbuf{argidx}[i]") + else: + scalar_args.append(arg) + code.writeline( + f"__at_align__ std::array<{octype}, {kernel.tiling_factor}> tmpbuf_out;" + ) + res = scalar_func(*scalar_args) + code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") + with code.indent(): + code.writeline(f"tmpbuf_out[i] = {res};") + if output_mask: + assert not kernel.tail_size + load_args = "tmpbuf_out.data()" + load_fn = f"at::vec::VecMask<{cdtype},{n_vec}>::from" + else: + load_args = f"tmpbuf_out.data(), {cexpr_index(size)}" + if n_vec == 1: + load_fn = f"at::vec::Vectorized<{octype}>::loadu" + else: + load_fn = f" at::vec::VectorizedN<{octype}, {n_vec}>::loadu" + code.writeline(f"return {load_fn}({load_args});") + code.writeline("()") + return code + + return inner + + @classmethod + def _initialize_scalarize(cls): + for name, method in vars(CppOverrides).items(): + if getattr(method, "__class__", None) == staticmethod and name not in vars( + CppVecOverrides + ): + func = cls.scalarize(method.__func__) + func.__name__ = name + setattr(cls, name, staticmethod(func)) + + +CppVecOverrides._initialize_pointwise_overrides("cppvec") +CppVecOverrides._initialize_scalarize() + + +class CppTile2DOverrides(CppVecOverrides): + @staticmethod + def index_expr(expr, dtype): + assert isinstance(V.kernel, CppTile2DKernel) + expr = V.kernel.transform_indexing(expr) + return CppVecOverrides.index_expr(expr, dtype) + + +class CppKernel(Kernel): + overrides = CppOverrides # type: ignore[assignment] + sexpr = cexpr + newvar_prefix = "auto " + suffix = ";" + + def __init__(self, args, num_threads): + super().__init__(args) + self.call_ranges: Optional[Tuple[sympy.Expr, ...]] = None + self.ranges: List[sympy.Expr] = [] + self.itervars: List[sympy.Symbol] = [] + self.reduction_depth = None + self.reduction_prefix = IndentedBuffer() + self.reduction_suffix = IndentedBuffer() + self.parallel_reduction_prefix = IndentedBuffer() + self.parallel_reduction_suffix = IndentedBuffer() + self.local_reduction_init = IndentedBuffer() + self.local_reduction_stores = IndentedBuffer() + self.is_reduction = False + self.non_parallel_reduction_prefix = IndentedBuffer() + self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") + self.weight_recps_cse = CSE( + self.newvar_prefix, self.suffix, name_prefix="wrecps" + ) + self.preloads = IndentedBuffer() + self.poststores = IndentedBuffer() + self.num_threads = num_threads # num_threads the kernel specialized for + self.reduction_omp_dec: Dict[Tuple[str, str], str] = {} + + def _gen_parallel_reduction_buffers( + self, + acc, + acc_type, + reduction_type, + dtype, + reduction_combine_fn=reduction_combine, + reduction_init_fn=reduction_init, + welford_weight_reciprocal_vec_fn=None, + ): + if config.cpp.dynamic_threads and not self.parallel_reduction_prefix: + self.parallel_reduction_prefix.writeline( + "int max_threads = omp_get_max_threads();" + ) + acc_local = f"{acc}_local" + num_threads = ( + "max_threads" if config.cpp.dynamic_threads else parallel_num_threads() + ) + acc_per_thread_var_name = f"{acc}_arr" + acc_per_thread = f"{acc_per_thread_var_name}[{num_threads}]" + """ + MSVC don't support dynamic array(VLA). Please use std::unique_ptr to instead of it. 
+ Ref: https://stackoverflow.com/questions/56555406/creating-dynamic-sized-array-using-msvc-c-compiler + MSVC is the only one compiler, which not support VLA. And MSVC can't get good inductor performance. + So, we can use unique_ptr make it works on MSVC. + For other compilers, we continue to use VLA to get best performence. + """ + acc_per_thread_unique_ptr_decl = f"auto {acc_per_thread_var_name} = std::make_unique<{acc_type}[]>({num_threads})" + acc_per_thread_vla_decl = f"{acc_per_thread_var_name}[{num_threads}]" + acc_local_in_array = acc_per_thread.replace(f"[{num_threads}]", "[tid]") + self.local_reduction_init.writeline( + f"{acc_type} {acc_local} = {reduction_init_fn(reduction_type, dtype)};" + ) + self.parallel_reduction_prefix.writeline( + f"{acc_per_thread_unique_ptr_decl};" + if cpp_builder.is_msvc_cl() + else f"{acc_type} {acc_per_thread_vla_decl};" + ) + self.parallel_reduction_prefix.writelines( + [ + f"for (int tid = 0; tid < {num_threads}; tid++)", + "{", + f" {acc_local_in_array} = {reduction_init_fn(reduction_type, dtype)};", + "}", + ], + ) + self.local_reduction_stores.writelines( + [ + f"{acc_local_in_array} = {acc_local};", + ] + ) + self.parallel_reduction_suffix.writelines( + [ + f"for (int tid = 0; tid < {num_threads}; tid++)", + "{", + f" {acc} = {reduction_combine_fn(reduction_type, acc, acc_local_in_array, src_dtype=dtype)};", + "}", + ], + ) + + def get_reduction_var_pattern(self, line: str): + return re.search("tmp_acc[0-9]+", line) + + def update_stores_with_parallel_reduction(self): + for i, line in enumerate(self.stores._lines): + if isinstance(line, str): + m = self.get_reduction_var_pattern(line) + if m: + var_name = m.group(0) + self.stores._lines[i] = line.replace(var_name, f"{var_name}_local") + + @contextlib.contextmanager + def masked(self, mask): + """Context manager to add an additional mask to loads and stores.""" + prior = self._load_mask + if prior: + mask = ops.and_(mask, prior) + if isinstance(mask, OpsValue): + mask = mask.value + assert isinstance(mask, CppCSEVariable) + # see NOTE [dtype of CppCSEVariable] + # mask's dtype should be bool + mask.dtype = torch.bool + + self._load_mask = mask + try: + yield mask + finally: + self._load_mask = prior + + def scale_index_with_offset( + self, index: sympy.Expr, scale=1, itervar_idx=-1, offset=0 + ): + var = self.itervars[itervar_idx] + replacement = {var: var * scale + offset} + new_index = sympy_subs(index, replacement) + return new_index + + def index_to_str(self, index: sympy.Expr) -> str: + """ + Convert an index expr to a string that can be used in cpp code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the cpp kernel. + """ + return cexpr(self.rename_indexing(index)) + + def index_indirect_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + """ + Check if an index has free symbol CppCSEVariable that depends on `itervar`. 
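# Editor's note (sketch, not upstream code): the per-thread accumulators set up by
# _gen_parallel_reduction_buffers above give every OpenMP thread its own slot, which the
# parallel_reduction_suffix loop then folds back into the main accumulator. The MSVC remark
# above amounts to choosing between these two declarations for a hypothetical float
# accumulator `tmp_acc0`:
vla_style_decl = "float tmp_acc0_arr[max_threads];"                              # GCC/Clang path (VLA)
msvc_style_decl = "auto tmp_acc0_arr = std::make_unique<float[]>(max_threads);"  # MSVC path (no VLA support)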
+ """ + return any( + self.cse.varname_map[s.name].depends_on(itervar) # type: ignore[attr-defined] + for s in index.free_symbols + if s.name in self.cse.varname_map # type: ignore[attr-defined] + and isinstance(self.cse.varname_map[s.name], CppCSEVariable) # type: ignore[attr-defined] + ) + + def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol): + return itervar in index.free_symbols or self.index_indirect_depends_on( + index, itervar + ) + + def var_ranges(self): + return dict(zip(self.itervars, self.ranges)) + + def check_bounds( + self, + expr: sympy.Expr, + size: sympy.Expr, + lower: bool, + upper: bool, + ): + if not (lower or upper): + return + + indirect = free_symbol_is_type(expr, SymT.TMP) + if indirect: + # indexing in compute + csevar = ops.index_expr(expr, torch.int64).value + buffer = V.kernel.compute + else: + # indexing in loads + prior_compute = V.kernel.compute + try: + V.kernel.compute = self.loads + csevar = ops.index_expr(expr, torch.int64).value + finally: + V.kernel.compute = prior_compute + buffer = self.loads + + size_str = V.kernel.sexpr(self.rename_indexing(size)) if upper else None + + line = self.indirect_assert( + csevar, "0" if lower else None, size_str, self._load_mask + ) + self.cse.generate(buffer, line, assignment=False) + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + line = f"{var}[{cexpr_index(index)}]" + csevar = self.cse.generate(self.loads, line) + csevar.update_on_args("load", (self, name, index), {}) + return csevar + + def store(self, name, index, value, mode=None): + assert "buf" in name + var = self.args.output(name) + index = self.rename_indexing(index) + if mode is None: + line = f"{var}[{cexpr_index(index)}] = {value};" + elif mode == "atomic_add": + if not config.cpp.dynamic_threads and self.num_threads == 1: + line = f"{var}[{cexpr_index(index)}] += {value};" + else: + dtype = V.graph.get_dtype(name) + # mirroring static_cast(...) 
in load: + value = f"static_cast<{DTYPE_TO_CPP[dtype]}>({value})" + line = f"atomic_add(&{var}[{cexpr_index(index)}], {value});" + else: + raise NotImplementedError(f"store mode={mode}") + self.stores.writeline(DeferredLine(name, line)) + + def reduction(self, dtype, src_dtype, reduction_type, value): + argmax_or_argmin = reduction_type in {"argmax", "argmin"} + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_cse.reduction_cache: + return self.reduction_cse.reduction_cache[reduction_key] + + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + self.is_reduction = True + init_dtype = src_dtype if argmax_or_argmin else dtype + acc_type = reduction_acc_type(reduction_type, init_dtype) + self.reduction_prefix.writeline( + f"{acc_type} {acc} = {reduction_init(reduction_type, init_dtype)};" + ) + assert self.reduction_depth is not None + index = self.itervars[self.reduction_depth] + for i in range(self.reduction_depth + 1, len(self.itervars)): + index = index * self.ranges[i] + self.itervars[i] + self.stores.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, value, index)};" + ) + self._gen_parallel_reduction_buffers(acc, acc_type, reduction_type, init_dtype) + result = reduction_project(reduction_type, acc) + self.reduction_cse.reduction_cache[reduction_key] = result + return result + + def store_reduction(self, name, index, value): + index = self.rename_indexing(index) + var = self.args.output(name) + self.reduction_suffix.writeline( + DeferredLine(name, f"{var}[{cexpr_index(index)}] = {value};") + ) + + def set_ranges(self, lengths, reduction_lengths): + if self.call_ranges: + assert self.call_ranges == tuple(lengths) + tuple( + reduction_lengths + ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}" + assert self.reduction_depth == len(lengths) + else: + self.call_ranges = tuple(lengths) + tuple(reduction_lengths) + self.ranges = [self.rename_indexing(x) for x in self.call_ranges] + self.itervars = [ + sympy_index_symbol_with_prefix(SymT.XBLOCK, n) + for n in range(len(self.ranges)) + ] + self.reduction_depth = len(lengths) + return ( + self.itervars[: self.reduction_depth], + self.itervars[self.reduction_depth :], + ) + + def size_hint(self): + return V.graph.sizevars.size_hint( + sympy_product(self.call_ranges), fallback=8192 + ) + + def codegen_loops_impl(self, loop_nest, code, worksharing): + threads = parallel_num_threads() + assert self.call_ranges is not None + kernels = loop_nest.get_kernels() + has_outer_loop_kernel = any( + isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels + ) + if has_outer_loop_kernel: + assert len(kernels) == 1 + assert isinstance(kernels[0], OuterLoopFusedKernel) + par_depth = kernels[0].decide_parallel_depth( + loop_nest.max_parallel_depth(), threads + ) + else: + par_depth = self.decide_parallel_depth( + loop_nest.max_parallel_depth(), threads + ) + + with contextlib.ExitStack() as stack: + if par_depth: + if loop_nest.is_reduction_only(): + # need to close the worksharing scope to define reduction vars outside it + worksharing.close() + else: + worksharing.parallel(threads) + loop_nest.mark_parallel(par_depth) + elif threads > 1: + if worksharing.single(): + stack.enter_context(code.indent()) + + def gen_loop_kernel(loop: LoopLevel): + def is_parallel_reduction(loop): + root = loop.get_root() + return root.is_reduction and root.parallel + + kernels = loop.get_kernels() + assert len(kernels) == 1 + if not isinstance( + kernels[0], 
OuterLoopFusedKernel + ) and is_parallel_reduction(loop): + kernels[0].update_stores_with_parallel_reduction() + gen_kernel(kernels[0]) + + def gen_kernel(kernel): + if isinstance(kernel, OuterLoopFusedKernel): + for loop in kernel.inner: + if loop.inner: + gen_loops(loop.inner, loop.is_reduction) + else: + with contextlib.ExitStack() as stack: + # If there is any kernel existing at the final outer loop fusion level, + # the kernel code should be placed within its respective indent to prevent + # the duplication of variable definitions. + stack.enter_context(code.indent()) + gen_loop_kernel(loop) + else: + with contextlib.ExitStack() as stack: + assert kernel + if hasattr(kernel, "codegen_inner_loops"): + code.splice(kernel.preloads) + kernel.codegen_inner_loops(code) + stack.enter_context(code.indent()) + code.splice(kernel.loads) + code.splice(kernel.compute) + code.splice(kernel.stores) + if hasattr(kernel, "codegen_inner_loops"): + code.splice(kernel.poststores) + + def get_reduction_code_buffer(loops, buffer="prefix"): + assert buffer in ("prefix", "suffix", "local") + for loop in loops: + for kernel in loop.get_kernels(): + if buffer == "local": + return ( + kernel.local_reduction_init, + kernel.local_reduction_stores, + ) + elif buffer == "suffix": + suffix = kernel.reduction_suffix + if loop.parallel: + suffix = kernel.parallel_reduction_suffix + suffix + return suffix + else: + prefix = kernel.reduction_prefix + if loop.parallel: + prefix = prefix + kernel.parallel_reduction_prefix + else: + prefix = prefix + kernel.non_parallel_reduction_prefix + return prefix + + def gen_loops(loops: List[LoopLevel], in_reduction=False): + with contextlib.ExitStack() as stack_outer: + local_reduction_init = local_reduction_stores = None + if loops: + loop = loops[0] + if loop.is_reduction and not in_reduction: + reduction_prefix = get_reduction_code_buffer(loops) + if reduction_prefix: + stack_outer.enter_context(code.indent()) + code.splice(reduction_prefix) + if loop_nest.is_reduction_only() and loop.parallel: + ( + local_reduction_init, + local_reduction_stores, + ) = get_reduction_code_buffer(loops, "local") + worksharing.parallel(threads) + if local_reduction_init: + assert local_reduction_stores + code.splice(local_reduction_init) + + for loop in loops: + gen_loop(loop) + + if loops: + loop = loops[0] + if loop_nest.is_reduction_only() and loop.parallel: + if local_reduction_stores: + code.splice(local_reduction_stores) + worksharing.close() + if loop.is_reduction and not in_reduction: + code.splice(get_reduction_code_buffer(loops, "suffix")) + + def gen_loop(loop: LoopLevel): + with contextlib.ExitStack() as stack: + loop_lines = loop.lines() + if loop_lines is None: + return + code.writelines(loop_lines) + stack.enter_context(code.indent()) + # generate inner loops or loop body + if loop.inner: + gen_loops(loop.inner, loop.is_reduction) + else: + gen_loop_kernel(loop) + + stack.enter_context(code.indent()) + if loop_nest.root: + if ( + has_outer_loop_kernel + and isinstance(V.local_buffer_context, LocalBufferContext) + and V.local_buffer_context.local_buffers + ): + # Allocate local buffer + local_buffers = V.local_buffer_context.local_buffers + for local_buffer in local_buffers.values(): + # For dynamic size, rename s to ks + local_buf_size = sympy_product( + [ + self.rename_indexing(size_val) + for size_val in local_buffer.get_layout().size + ] + ) + local_buf_dtype = DTYPE_TO_CPP[local_buffer.get_layout().dtype] + allocate = f"std::make_unique<{local_buf_dtype} 
[]>({cexpr(local_buf_size)})" + local_buffer_name = local_buffer.get_name() + code.splice( + f"std::unique_ptr<{local_buf_dtype} []> buf_{local_buffer_name} = {allocate};" + ) + code.splice( + f"{local_buf_dtype}* {local_buffer_name} = buf_{local_buffer_name}.get();" + ) + gen_loops(loop_nest.root) + else: + gen_kernel(loop_nest.kernel) + + def codegen_loops(self, code, worksharing): + loop_nest = LoopNestWithSplit.build(self) + self.codegen_loops_impl(loop_nest, code, worksharing) + + @property + def assert_function(self) -> str: + if V.graph.aot_mode: + # TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models + # compared with JIT Inductor which uses TORCH_CHECK + return "AOTI_TORCH_CHECK" + else: + return "TORCH_CHECK" + + def decide_parallel_depth(self, max_parallel_depth, threads): + assert self.call_ranges is not None + ranges = self.call_ranges[:max_parallel_depth] + seq = self.size_hint() + par = 1 + depth = 0 + for expr in ranges: + hint = V.graph.sizevars.size_hint(expr, fallback=8192) + if par >= 2 * threads or par == threads: + break + if seq // threads < config.cpp.min_chunk_size: + # not enough work + break + depth += 1 + par *= hint + seq /= hint + # if we assume thread number is dynamic, make sure we + # have at least one parallel scope and let OMP runtime + # to manage the serial vs. parallel. + if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0: + depth = 1 + return depth + + @contextlib.contextmanager + def write_to_suffix(self): + prior = (self.loads, self.compute, self.stores, self.cse) + self.loads = IndentedBuffer() + self.compute = IndentedBuffer() + self.stores = IndentedBuffer() + self.cse = self.cse.clone() + yield + self.reduction_suffix.splice(self.loads) + self.reduction_suffix.splice(self.compute) + self.reduction_suffix.splice(self.stores) + (self.loads, self.compute, self.stores, self.cse) = prior + + def create_cse_var(self, *args, **kwargs): + return CppCSEVariable(*args, **kwargs) + + def get_to_dtype_expr(self, src, dtype, src_dtype): + return f"c10::convert<{DTYPE_TO_CPP[dtype]}>({src})" + + def cache_dtype_convert(self, dst, dst_dtype, src, src_dtype): + expr = self.get_to_dtype_expr(src, dst_dtype, src_dtype) + self.cse.cache[expr] = dst + + +class CppVecKernel(CppKernel): + overrides = CppVecOverrides # type: ignore[assignment] + + def __init__( + self, + args, + num_threads, + tiling_factor, + tiling_idx, + tail_size=None, + ): + super().__init__(args, num_threads) + self.vec_isa = cpu_vec_isa.pick_vec_isa() + assert self.vec_isa + assert tiling_factor > 0, "Expect pass in Non-Zero tiling_factor explicitly" + self.tiling_factor = tiling_factor + self.tiling_idx = tiling_idx + self.tail_size = tail_size + self.num_elems = tail_size if tail_size else tiling_factor + + def _try_get_const_stride(self, index: sympy.Expr, itervar: sympy.Symbol): + if self.index_indirect_depends_on(index, itervar): + return None + for indirect_var in ( + self.cse.varname_map[s.name] # type: ignore[attr-defined] + for s in index.free_symbols + if symbol_is_type(s, SymT.TMP) + ): + assert isinstance(indirect_var, CppCSEVariable) + if indirect_var.is_vec: + return None + stride = stride_at_vec_range(index, itervar, self.tiling_factor) + return stride if stride.is_number else None + + def _get_num_vectors(self, dtype: torch.dtype) -> int: + num_vectors = math.ceil( + self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width() + ) + assert num_vectors >= 1 + return num_vectors + + def _get_raw_num_vectors(self, dtype: torch.dtype) -> 
float: + # This utility function is used to check if the vector lanes has been + # fully utilized. For example, uint8 will only use 1/4 of the vector lanes. + return self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width() + + def _get_vec_type(self, dtype: torch.dtype) -> str: + num_vectors = self._get_num_vectors(dtype) + if num_vectors == 1: + return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>" + else: + return f"at::vec::VectorizedN<{DTYPE_TO_CPP[dtype]},{num_vectors}>" + + def _get_mask_type(self, dtype: torch.dtype = torch.float) -> str: + if dtype == torch.bool: + return "" + num_vectors = self._get_num_vectors(dtype) + return f"at::vec::VecMask<{DTYPE_TO_CPP[dtype]},{num_vectors}>" + + def _get_mask_cast(self, mask: CppCSEVariable, dtype: torch.dtype) -> str: + assert mask.dtype == torch.bool, repr(mask) + num_vectors = self._get_num_vectors(dtype) + return f"{mask}.template cast<{DTYPE_TO_CPP[dtype]},{num_vectors}>()" + + def get_reduction_var_pattern(self, line: str): + return re.search("tmp_acc[0-9]+_vec", line) + + def _get_vec_load_line( + self, + var: str, + index: sympy.Expr, + dtype: torch.dtype, + load_mask: Optional[CppCSEVariable] = None, + ): + """ + Get a load line str that loads a vector from `var` at `index` of type `dtype`. + If `load_mask` is not None, we do a masked load accordingly. + Notes on the `dtype`: + 1. We always load `self.tiling_factor` number of elements regardless of the `dtype`. + It means we load half of the vector lanes for 16-bit data types and quarter of the + vector lanes for 8-bit data types. + 2. `torch.bool` and `torch.uint8` could mean masks and we load them as float mask vectors. + """ + cpp_type = DTYPE_TO_CPP[dtype] + num_vectors = self._get_num_vectors(dtype) + load_mask_str = None + if load_mask: + if not load_mask.is_vec: + # TODO: avoid hard-code torch.float + load_mask_str = f"{self._get_mask_type(torch.float)}::from({load_mask})" + else: + load_mask_str = f"{self._get_mask_cast(load_mask, torch.float)}" + loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var + if dtype == torch.bool: + # TODO: should we consider load mask here? + line = f"{self._get_mask_type()}::from({loadbuf})" + else: + line = ( + f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})" + if load_mask_str + else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {cexpr_index(self.num_elems)})" + ) + return line + + def _load_or_store_non_contiguous( + self, + var: Optional[str], + index: sympy.Expr, + dtype: torch.dtype, + buffer: Optional[IndentedBuffer] = None, + store_value: Optional[Union[str, CppCSEVariable]] = None, + accu_store: bool = False, + ) -> Optional[CppCSEVariable]: + """ + Load or store a vector in a non-contiguous way. The vector is initialized from an array that is + filled in an inner loop over the tiling factor. + :param var: buffer to load from or store to, i.e. `var[transformed(index)]`. If None, we load the index + as index expression, i.e. `transformed(index)`. + :param index: index into the `var` or the index expression by its own if `var` is None. + The `index` could contain indirect indexing or the tiling itervar. When used in + the inner loop, the index is transformed as follows: + 1. the index is linearized along the tiling dim. + 2. the indirect indexing vector variables are transformed into arrays over the tiling dim. + :param dtype: data type of `var` or `index` if `var` is None. + :param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`. 
+ :param store_value: the value to store. If None, we load the vector. + :param accu_store: whether accumulate the store_value to store_ptr. If True, a store_value should be provided + :return: a CppCSEVariable that represents the loaded vector or None if it is a store. + """ + assert not store_value or var is not None, "store var must be provided" + if accu_store: + assert store_value + if buffer is None: + buffer = self.loads + + def get_result_size(dtype: torch.dtype) -> int: + if dtype.itemsize < 4: + return self.num_elems * (4 // dtype.itemsize) + else: + return self.num_elems + + def get_tiling_size(dtype: torch.dtype) -> int: + if dtype.itemsize < 4: + return self.tiling_factor * (4 // dtype.itemsize) + else: + return self.tiling_factor + + def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: + assert vec_var.is_vec + code = BracesBuffer() + code.writeline("[&]") + with code.indent(): + vec_dtype = vec_var.dtype + assert vec_dtype is not None + if vec_dtype == torch.bool: + vec_dtype = torch.float + result_size = get_result_size(vec_dtype) + tiling_size = get_tiling_size(vec_dtype) + code.writeline( + f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {tiling_size}> tmpbuf;" + ) + line = f"{vec_var}.store(tmpbuf.data(), {cexpr_index(result_size)});" + code.writeline(line) + code.writeline("return tmpbuf;") + code.writeline("()") + csevar = self.cse.generate(buffer, code) + assert isinstance(csevar, CppCSEVariable) + return csevar + + code = BracesBuffer() + code.writeline("[&]") + with code.indent(): + result_size = get_result_size(dtype) + tiling_size = get_tiling_size(dtype) + result_declare = ( + f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {tiling_size}> tmpbuf;" + ) + code.writeline(result_declare) + if store_value: + code.writeline( + f"{store_value}.store(tmpbuf.data(), {cexpr_index(result_size)});" + ) + itervar_inner = sympy_index_symbol( + f"{self.itervars[self.tiling_idx]}_inner" + ) + replacements = {} + for indirect_var in ( + self.cse.varname_map[s.name] # type: ignore[attr-defined] + for s in index.free_symbols + if symbol_is_type(s, SymT.TMP) + ): + assert isinstance(indirect_var, CppCSEVariable) + if indirect_var.is_vec: + array_var = vec_to_array(indirect_var) + replacements[indirect_var] = f"{array_var}[{itervar_inner}]" + index = self.scale_index_with_offset( + index, itervar_idx=self.tiling_idx, offset=itervar_inner + ) + load_mask = None + if self._load_mask is not None: + assert not store_value, "unexpected store with load mask" + assert isinstance(self._load_mask, CppCSEVariable), self._load_mask + if self._load_mask.is_vec: + load_mask = f"{self._load_mask}.is_masked({itervar_inner})" + else: + load_mask = f"{self._load_mask} != 0" + if cpp_builder.is_gcc(): + code.writeline(f"#pragma GCC unroll {self.tiling_factor}") + else: + code.writeline(f"#pragma unroll {self.tiling_factor}") + code.writeline( + f"for (long {itervar_inner} = 0; " + + f"{itervar_inner} < {cexpr_index(self.num_elems)}; " + + f"{itervar_inner}++)" + ) + with code.indent(), contextlib.ExitStack() as stack: + index_c = cexpr_index(index) + for indirect_var in replacements: + index_c = re.sub( + r"\b" + f"{indirect_var}" + r"\b", + replacements[indirect_var], + index_c, + ) + rhs = f"{var}[{index_c}]" if var is not None else f"{index_c}" + if load_mask: + code.writeline(f"if ({load_mask})") + stack.enter_context(code.indent()) + if store_value: + conjunction = "+=" if accu_store else "=" + code.writeline(f"{rhs} {conjunction} tmpbuf[{itervar_inner}];") + else: + 
code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};") + if not store_value: + load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype) # type: ignore[arg-type] + code.writeline(f"return {load_line};") + code.writeline("()") + if store_value: + code.writeline(";") + buffer.splice(code) + return None + else: + csevar = self.cse.generate(buffer, code) + assert isinstance(csevar, CppCSEVariable) + csevar.is_vec = True + return csevar + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + dtype = V.graph.get_dtype(name) + tiling_var = self.itervars[self.tiling_idx] + stride = self._try_get_const_stride(index, tiling_var) + if stride == 0: + # load scalar and lazily broadcast it on demand + return super().load(name, index) + elif stride == 1: + # load contiguously + line = self._get_vec_load_line(var, index, dtype, self._load_mask) + csevar = self.cse.generate(self.loads, line) # type: ignore[assignment] + else: + csevar = self._load_or_store_non_contiguous(var, index, dtype) # type: ignore[assignment] + assert isinstance(csevar, CppCSEVariable) + csevar.update_on_args("load", (self, name, index), {}) + csevar.is_vec = True + return csevar + + def _get_store_line( + self, + value: Union[str, CppCSEVariable], + var: str, + index: sympy.Expr, + dtype: torch.dtype, + accu_store: bool = False, + ): + """ + Get a store line buffer that stores `value` into `var` at `index` of `dtype`. It handles + both contiguous and non-contiguous store cases. + :param value: Vectorized type templaterized on `dtype`. + :param var: buffer to store into. + :index: index into the `var`. + """ + # when value's type is str (e.g., welford reduction), caller should make sure + # it is a vector + assert isinstance(value, str) or ( + isinstance(value, CppCSEVariable) and value.is_vec + ), value + tiling_var = self.itervars[self.tiling_idx] + var_expr = f"{var} + {cexpr_index(index)}" + stride = self._try_get_const_stride(index, tiling_var) + code = IndentedBuffer() + if stride == 1: + if dtype == torch.float and self.tail_size is None: + code.writeline(f"{value}.store({var_expr});") + else: + code.writeline( + f"{value}.store({var_expr}, {cexpr_index(self.num_elems)});" + ) + else: + self._load_or_store_non_contiguous( + var, index, dtype, buffer=code, store_value=value, accu_store=accu_store + ) + return code + + def store(self, name, index, value, mode=None): + assert "buf" in name + assert isinstance(value, CppCSEVariable), value + if not value.is_vec: + # this happens when we store a scalar into a vectorized buffer like "fill" + value = self.broadcast(value) + var = self.args.output(name) + index = self.rename_indexing(index) + dtype = V.graph.get_dtype(name) + if mode is None: + code = self._get_store_line(value, var, index, dtype) + self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + elif mode == "atomic_add": + if not config.cpp.dynamic_threads and self.num_threads == 1: + code = self._get_store_line( + f"{value}", + var, + index, + dtype, + accu_store=True, + ) + self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + else: + n_src = self._get_num_vectors(dtype) + n_idx = self._get_num_vectors(torch.int64) + cdtype = DTYPE_TO_CPP[dtype] + index = ops.index_expr(index, torch.int64).value + assert index.is_vec + line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});" + self.stores.writeline(DeferredLine(name, line)) + else: + raise NotImplementedError(f"store mode={mode}") + + def reduction(self, dtype, 
src_dtype, reduction_type, value): + assert reduction_type in VECTORIZABLE_RTYPES + argmax_or_argmin = reduction_type in {"argmax", "argmin"} + horizontal_reduction = self.tiling_idx >= self.reduction_depth + init_dtype = src_dtype if argmax_or_argmin else dtype + assert isinstance(value, CppCSEVariable), value + + if not value.is_vec: + value = self.broadcast(value) + + reduction_key = src_dtype, reduction_type, value + if reduction_key in self.reduction_cse.reduction_cache: + return self.reduction_cse.reduction_cache[reduction_key] + + vec_ns = "at::vec" + vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>" + acc_type = reduction_acc_type(reduction_type, init_dtype) + acc_type_vec = self.reduction_acc_type_vec(reduction_type, init_dtype) + + acc = self.reduction_cse.generate( + self.loads, f"reduction {reduction_key}", write=False + ) + acc_vec = f"{acc}_vec" + self.is_reduction = True + self.reduction_prefix.writeline( + f"{acc_type} {acc} = {reduction_init(reduction_type, init_dtype)};" + ) + self.reduction_prefix.writeline( + f"{acc_type_vec} {acc_vec} = {self.reduction_init_vec(reduction_type, init_dtype)};" + ) + if reduction_type == "welford_reduce": + # save the reciprocal of weights for welford reduce + assert self.reduction_depth is not None + # use masked acc_vec for tail vec kernel + self.reduction_prefix.writeline( + f"{acc_type_vec} masked_{acc_vec} = {self.reduction_init_vec(reduction_type, dtype)};" + ) + reduction_size = functools.reduce( + lambda x, y: x * y, self.ranges[self.reduction_depth :] + ) + reduction_factor = ( + self.tiling_factor if self.tiling_idx >= self.reduction_depth else 1 + ) + self.weight_recp_vec_range = FloorDiv(reduction_size, reduction_factor) + if self.weight_recp_vec_range not in self.weight_recps_cse.reduction_cache: + self.weight_recps_val = self.weight_recps_cse.generate( + self.compute, f"reduction {self.weight_recp_vec_range}", write=False + ) + self.weight_recps_cse.reduction_cache[ + self.weight_recp_vec_range + ] = self.weight_recps_val + self.non_parallel_reduction_prefix.writeline( + self.welford_weight_reciprocal_vec(dtype) + ) + # generate weight_recps for parallel reduction + num_threads = ( + "max_threads" + if config.cpp.dynamic_threads + else parallel_num_threads() + ) + self.local_reduction_init.writeline( + self.welford_weight_reciprocal_vec(dtype, num_threads) + ) + else: + self.weight_recps_val = self.weight_recps_cse.reduction_cache[ + self.weight_recp_vec_range + ] + # use masked acc_vec for tail vec kernel + acc_vec_ = f"masked_{acc_vec}" if self.tail_size else acc_vec + self.stores.writeline( + f"{acc_vec_} = {self.reduction_combine_vec(reduction_type, acc_vec_, value, True)};" + ) + else: + assert self.reduction_depth is not None + index = self.itervars[self.reduction_depth] + for i in range(self.reduction_depth + 1, len(self.itervars)): + index = index * self.ranges[i] + self.itervars[i] + combine = self.reduction_combine_vec( + reduction_type, + acc_vec, + value, + index=index, + horizontal_reduction=horizontal_reduction, + src_dtype=src_dtype, + ) + self.stores.writeline(f"{acc_vec} = {combine};") + self._gen_parallel_reduction_buffers( + acc, + acc_type, + reduction_type, + init_dtype, + ) + self._gen_parallel_reduction_buffers( + acc_vec, + acc_type_vec, + reduction_type, + init_dtype, + reduction_combine_fn=self.reduction_combine_vec, + reduction_init_fn=self.reduction_init_vec, + ) + if reduction_type == "welford_reduce": + # use masked acc_vec for tail vec kernel + self._gen_parallel_reduction_buffers( + 
f"masked_{acc_vec}", + acc_type_vec, + reduction_type, + dtype, + reduction_combine_fn=self.reduction_combine_vec, + reduction_init_fn=self.reduction_init_vec, + ) + tmpvar: Union[str, CSEVariable] + is_bool = dtype == torch.bool + if horizontal_reduction: + # Horizontal reduction + if is_welford_reduction(reduction_type): + assert self._get_num_vectors(dtype) in [ + 1, + 2, + ], "Welford reduction does not support VectorizedN (N>2)" + next_value = f"welford_vec_reduce_all({acc_vec})" + masked_next_value = f"welford_vec_reduce_all(masked_{acc_vec})" + self.reduction_suffix.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, masked_next_value)};" + ) + elif argmax_or_argmin: + next_value = f"{reduction_type}_vec_reduce_all({acc_vec})" + elif is_bool: + if reduction_type in ( + "any", + "sum", + "max", + ): + next_value = f"!{acc_vec}.all_zero()" + else: + assert reduction_type == "min" + next_value = f"{acc_vec}.all_masked()" + else: + reduce_all_body = ( + "{ return " + + self.reduction_combine_vec(reduction_type, "x", "y") + + "; }" + ) + is_bool = dtype == torch.bool + # we are using at::vec::VecMask for bool + vec_dtype = torch.float if is_bool else dtype + vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[vec_dtype]}>" + vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[vec_dtype]}, {self._get_num_vectors(vec_dtype)}>" + next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {acc_vec})" + + self.reduction_suffix.writeline( + f"{acc} = {reduction_combine(reduction_type, acc, next_value, src_dtype=src_dtype)};" + ) + tmpvar = acc + else: + tmpvar = acc_vec + if is_welford_reduction(reduction_type): + masked_tmpvar = f"masked_{tmpvar}" + self.reduction_suffix.writeline( + f"{tmpvar} = {reduction_combine(reduction_type, tmpvar, masked_tmpvar)};" + ) + + result = reduction_project(reduction_type, tmpvar) + self.reduction_cse.reduction_cache[reduction_key] = result + return result + + def store_reduction(self, name, index, value): + index = self.rename_indexing(index) + var = self.args.output(name) + out_dtype = V.graph.get_dtype(name) + dtype = ( + (out_dtype if out_dtype == torch.double else torch.float) + if out_dtype.is_floating_point + else torch.int64 + ) + out_num_vectors = V.kernel._get_num_vectors(out_dtype) + src_num_vectors = V.kernel._get_num_vectors(dtype) + code = IndentedBuffer() + if self.tiling_idx >= self.reduction_depth: + # Horizontal reduction + code.writeline( + f"{var}[{cexpr_index(index)}] = static_cast<{DTYPE_TO_CPP[out_dtype]}>({value});" + ) + else: + # Vertical reduction + if out_dtype != dtype: + converted_value = f"{DTYPE_TO_CPP[out_dtype]}_{value}" + if out_dtype == torch.bool: + convert = f"{value}.template cast()" + else: + if src_num_vectors == out_num_vectors == 1: + convert = ( + f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value})" + ) + else: + convert = ( + f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}," + f"{out_num_vectors},{DTYPE_TO_CPP[dtype]},{src_num_vectors}>({value})" + ) + code.writeline(f"auto {converted_value} = {convert};") + value = converted_value + code.splice(self._get_store_line(value, var, index, out_dtype)) + self.reduction_suffix.splice(code.map(lambda x: DeferredLine(name, x))) + + def broadcast(self, scalar_var: CppCSEVariable) -> CppCSEVariable: + assert not scalar_var.is_vec + if scalar_var.dtype == torch.bool: + vec_var = self.cse.generate( + self.compute, f"{self._get_mask_type()}::from({scalar_var.name})" + ) + else: + assert scalar_var.dtype is not None + vec_var = self.cse.generate( + 
self.compute, + f"{self._get_vec_type(scalar_var.dtype)}({scalar_var.name})", + ) + assert isinstance(vec_var, CppCSEVariable) + vec_var.dtype = scalar_var.dtype + vec_var.dependent_itervars = scalar_var.dependent_itervars + vec_var.is_vec = True + return vec_var + + def arange(self, index: CppCSEVariable, stride: sympy.Symbol) -> CppCSEVariable: + assert not index.is_vec + assert index.dtype is not None + csevar = self.cse.generate( + self.compute, + f"{self._get_vec_type(index.dtype)}::arange({index}, {stride})", + ) + assert isinstance(csevar, CppCSEVariable) + csevar.dtype = index.dtype + csevar.is_vec = True + return csevar + + def reduction_init_vec(self, reduction_type, dtype): + scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] + vec_type = self._get_vec_type(scalar_type) + + if is_welford_reduction(reduction_type): + return f"Welford<{vec_type}>()" + + if reduction_type in {"argmin", "argmax"}: + cdtype = DTYPE_TO_CPP[scalar_type] + acc_type = self.reduction_acc_type_vec(reduction_type, dtype) + if reduction_type == "argmin": + val = ( + f"std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::max()" + ) + else: + val = ( + f"-std::numeric_limits<{cdtype}>::infinity()" + if is_float_dtype(dtype) + else f"std::numeric_limits<{cdtype}>::min()" + ) + return f"{acc_type}({val})" + + if reduction_type == "any": + return f"{self._get_mask_type()}::from(0)" + + scalar_init = reduction_init(reduction_type, dtype) + vec_init = f"{vec_type}({scalar_init})" + if dtype == torch.bool: + assert reduction_type in ("min", "max", "sum") + return f"{self._get_mask_type()}::from({scalar_init})" + return vec_init + + def reduction_acc_type_vec(self, reduction_type, dtype): + scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype] + vec_type = self._get_vec_type(scalar_type) + if is_welford_reduction(reduction_type): + return f"Welford<{vec_type}>" + if reduction_type in {"argmin", "argmax"}: + n_src = self._get_num_vectors(scalar_type) + n_idx = self._get_num_vectors(torch.int64) + return f"IndexValueVec<{DTYPE_TO_CPP[scalar_type]}, {n_src}, {n_idx}>" + if dtype == torch.bool: + assert reduction_type in ("min", "max", "any", "sum") + return f"{self._get_mask_type()}" + return vec_type + + def welford_weight_reciprocal_vec(self, dtype, num_threads=None): + vec_num_range_thread = ( + CeilDiv(self.weight_recp_vec_range, num_threads) + if num_threads + else self.weight_recp_vec_range + ) + vec_num_range_thread_expr = cexpr_index(vec_num_range_thread) + return ( + f"static WeightRecp<{self._get_vec_type(dtype)}> {self.weight_recps_val}" + f"(" + f"{vec_num_range_thread_expr}" + f");" + ) + + def reduction_combine_vec( + self, + reduction_type, + var, + next_value, + use_weight_recps=False, + index: Optional[sympy.Symbol] = None, + horizontal_reduction: Optional[bool] = None, + src_dtype: Optional[torch.dtype] = torch.float32, + ): + is_bool = src_dtype == torch.bool + if reduction_type == "max": + if self.tail_size: + return f"max_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return ( + f"{var} | {next_value}" + if is_bool + else f"at::vec::maximum({var}, {next_value})" + ) + elif reduction_type == "min": + if self.tail_size: + return f"min_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return ( + f"{var} & {next_value}" + if is_bool + else f"at::vec::minimum({var}, {next_value})" + ) + elif reduction_type == "sum": + if self.tail_size: + return f"sum_masked_reduce({var}, {next_value}, 
{cexpr_index(self.tail_size)})" + else: + conjunction = "|" if is_bool else "+" + return f"{var} {conjunction} {next_value}" + elif reduction_type == "prod": + if self.tail_size: + return f"prod_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"{var} * {next_value}" + elif reduction_type == "xor_sum": + if self.tail_size: + return f"xor_sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"{var} ^ {next_value}" + elif reduction_type == "welford_reduce": + if use_weight_recps: + if self.tail_size: + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)}, &{self.weight_recps_val})" + else: + return f"welford_combine({var}, {next_value}, &{self.weight_recps_val})" + else: + if self.tail_size: + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)})" + else: + return f"welford_combine({var}, {next_value})" + elif reduction_type == "welford_combine": + if isinstance(next_value, tuple): + # When reading a value from Inductor IR we have a tuple of variable names + mean, m2, weight = next_value + else: + # When combining intermediate accumulators we have a Welford struct + mean, m2, weight = reduction_project(reduction_type, next_value) + if self.tail_size: + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}}, {cexpr_index(self.tail_size)})" + else: + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" + elif reduction_type in ("argmin", "argmax"): + assert src_dtype is not None + cdtype = DTYPE_TO_CPP[src_dtype] + n_src = self._get_num_vectors(src_dtype) + n_idx = self._get_num_vectors(torch.int64) + t_extra = "" + arg_extra = "" + if index is not None: + assert horizontal_reduction is not None + t_extra = f", {str(horizontal_reduction).lower()}" + arg_extra = f", {index}" + if self.tail_size: + return ( + f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>" + f"({var}, {next_value}{arg_extra}, {cexpr_index(self.tail_size)})" + ) + else: + return f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>({var}, {next_value}{arg_extra})" + elif reduction_type == "any": + return f"{var} | {next_value}" + else: + raise NotImplementedError + + def indirect_assert(self, var, lower, upper, mask=None): + assert isinstance(var, CppCSEVariable) + assert var.dtype is not None + if not var.is_vec: + if isinstance(mask, CppCSEVariable) and mask.is_vec: + mask = f"({mask}).all_masked()" + return super().indirect_assert(var, lower, upper, mask) + lower_scalar = lower + upper_scalar = upper + if lower: + lower = f"{self._get_vec_type(var.dtype)}({lower})" + if upper: + upper = f"{self._get_vec_type(var.dtype)}({upper})" + if lower and upper: + cond = f"({lower} <= {var}) & ({var} < {upper})" + cond_print = f"{lower_scalar} <= {var} < {upper_scalar}" + elif lower: + cond = f"{lower} <= {var}" + cond_print = f"{lower_scalar} <= {var}" + else: + assert upper + cond = f"{var} < {upper}" + cond_print = f"{var} < {upper_scalar}" + cond = f"{self._get_mask_type(var.dtype)}({cond})" + if mask: + if not mask.is_vec: + mask = f"{self._get_mask_type(var.dtype)}({mask})" + # We need not check when the mask is False + cond = f"({cond}) | ~({mask})" + if self.tail_size: + cond = ( + f"{self._get_mask_type(var.dtype)}::set({self._get_mask_type(var.dtype)}::from(1)" + f", ({cond}), {cexpr_index(self.tail_size)})" + ) + cond = f"({cond}).all_masked()" + return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' + + def get_to_dtype_expr(self, src, dtype, 
src_dtype): + assert isinstance(src, CppCSEVariable) + if not src.is_vec: + return super().get_to_dtype_expr(src, dtype, src_dtype) + src_cpp_type = DTYPE_TO_CPP[src_dtype] + src_num_vectors = self._get_num_vectors(src_dtype) + dst_cpp_type = DTYPE_TO_CPP[dtype] + dst_num_vectors = self._get_num_vectors(dtype) + expr = f"({src})" + if src_dtype != torch.bool and dtype == torch.bool: + expr = f"{self._get_mask_type(src_dtype)}::from<{src_cpp_type},{src_num_vectors}>({src})" + elif src_dtype == torch.bool and dtype != torch.bool: + expr = f"{src}.to<{dst_cpp_type},{dst_num_vectors}>()" + elif src_dtype != dtype: + if src_num_vectors == dst_num_vectors == 1: + expr = f"at::vec::convert<{dst_cpp_type}>({src})" + else: + expr = f"at::vec::convert<{dst_cpp_type},{dst_num_vectors},{src_cpp_type},{src_num_vectors}>({src})" + return expr + + +class CppTile2DKernel(CppVecKernel): + """ + A vector kernel that handles the 2d tiles with the tile size defined in `tiling_factor` on + the inner-most loop level and one of the outer loop level (`outer_tiling_idx`). When the data + tile is accessed in a contiguous way from the outer loop axis, a transposition is applied on the + tile to make the access contiguous from the inner-most loop axis. Then, the same vectorization + logic from its parent `CppVecKernel` is leveraged for load/store/compute. The transposed tile load + and store are generated into kernel.preloads and kernel.poststores buffers. + + The loop structure looks like below: + for ... + for i_outer ... + for ... + for inner_most ... + // generated by CppTile2DKernel + float tmp0[16*16]; at::vec::transpose_mxn<...>(tmp0, in_ptr0 + ..., ...); // into kernel.preloads + float tmp1[16*16]; // into kernel.preloads + for i_inner ... { // the kernel inner loop + vectorized loads/compute/stores (e.g., load tmp0, store tmp1) // into kernel.loads/compute/stores + } + at::vec::transpose_mxn(out_ptr0 + ..., tmp1, ...) // into kernel.poststores + for inner_most ... (tail) + // generated by CppVecKernel + ... + for i_outer ... (tail) + for ... + for ... + // generated by CppKernel + ... 
+ """ + + overrides = CppTile2DOverrides # type: ignore[assignment] + + def __init__( + self, + args, + num_threads, + tiling_factor, + tiling_indices, + inner_tail_size=None, + outer_tail_size=None, + ): + super().__init__( + args, + num_threads, + tiling_factor, + tiling_indices[1], + inner_tail_size, + ) + self.tiling_indices = tiling_indices + self.inner_tail_size = inner_tail_size + self.outer_tail_size = outer_tail_size + self.inner_num_elems = inner_tail_size if inner_tail_size else tiling_factor + self.outer_num_elems = outer_tail_size if outer_tail_size else tiling_factor + self.inner_is_tiling_idx = True + + def inner_itervar(self): + return sympy_index_symbol(f"{self.itervars[self.outer_idx]}_inner") + + def need_vec_transpose(self, index): + outer_var = self.itervars[self.outer_idx] + inner_var = self.itervars[self.tiling_idx] + outer_stride = stride_at_vec_range(index, outer_var, self.tiling_factor) + inner_stride = stride_at_vec_range(index, inner_var, self.tiling_factor) + return ( + self._load_mask is None # TODO: support transposition with mask + and outer_stride == 1 + and index.has(inner_var) + and not inner_stride.has(inner_var) + and not inner_stride.has(outer_var) + ) + + def gen_transposed_tile_load_store(self, name, var, index, is_store): + # transposed tile load/store outside the kernel inner loop + dtype = V.graph.get_dtype(name) + factor = self.tiling_factor + src = f"{var} + {cexpr_index(index)}" + dst = "__place_holder__" + ld_src = f"{cexpr_index(stride_at_vec_range(index, self.itervars[self.tiling_idx], self.tiling_factor))}" + ld_dst = f"{cexpr_index(self.num_elems)}" + if is_store: + src, dst = dst, src + ld_src, ld_dst = ld_dst, ld_src + + need_define = True + if self.inner_is_tiling_idx ^ is_store: + M, N = self.inner_num_elems, self.outer_num_elems + else: + M, N = ( + self.outer_num_elems, + self.inner_num_elems, + ) + if (isinstance(M, sympy.Expr) and not M.is_number) or ( + isinstance(N, sympy.Expr) and not N.is_number + ): + load_or_store = ( + f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]}>" + f"({src}, {ld_src}, {dst}, {ld_dst}, {cexpr_index(M)}, {cexpr_index(N)});" + ) + else: + load_or_store = ( + f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{cexpr_index(M)},{cexpr_index(N)}>" + f"({src}, {ld_src}, {dst}, {ld_dst});" + ) + if is_store: + tile_var = self.cse.newvar() + elif load_or_store not in self.cse.cache: + tile_var = self.cse.generate(self.preloads, load_or_store, write=False) + else: + need_define = False + tile_var = self.cse.cache[load_or_store] + + if need_define: + define_line = f"alignas({factor}) {DTYPE_TO_CPP[dtype]} {tile_var}[{factor}*{factor}];" + self.preloads.writeline(define_line) + + load_or_store = load_or_store.replace("__place_holder__", str(tile_var)) + if is_store: + self.poststores.writeline(DeferredLine(name, load_or_store)) + else: + self.preloads.writeline(load_or_store) + + return tile_var + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + + inner = self.inner_itervar() + if self.need_vec_transpose(index): + tile_var = self.gen_transposed_tile_load_store( + name, var, index, is_store=False + ) + # vector load inside the kernel inner loop + loadbuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}" + dtype = V.graph.get_dtype(name) + line = self._get_vec_load_line(loadbuf, 0, dtype) # type: ignore[arg-type] + csevar = self.cse.generate(self.loads, line) + csevar.update_on_args("load", (self, name, index), {}) + assert isinstance(csevar, 
CppCSEVariable) + csevar.is_vec = True + return csevar + else: + new_index = self.transform_indexing(index) + return super().load(name, new_index) + + def store(self, name, index, value, mode=None): + assert "buf" in name + var = self.args.output(name) + + inner = self.inner_itervar() + index = self.rename_indexing(index) + assert mode is None + if self.need_vec_transpose(index): + tile_var = self.gen_transposed_tile_load_store( + name, var, index, is_store=True + ) + # vector store inside the kernel inner loop + storebuf = f"{tile_var} + {cexpr_index(inner * self.num_elems)}" + if self.tail_size or V.graph.get_dtype(name) in DTYPE_LOWP_FP + [ + torch.uint8, + torch.int8, + ]: + line = f"{value}.store({storebuf}, {cexpr_index(self.num_elems)});" + else: + line = f"{value}.store({storebuf});" + self.stores.writeline(DeferredLine(name, line)) + else: + new_index = self.transform_indexing(index) + super().store(name, new_index, value, mode) + + def codegen_inner_loops(self, code): + inner = self.inner_itervar() + if self.inner_is_tiling_idx: + code.writeline( + f"for (long {inner} = 0; {inner} < {cexpr_index(self.outer_num_elems)}; {inner}++)" + ) + else: + code.writeline( + f"for (long {inner} = 0; {inner} < {cexpr_index(self.inner_num_elems)}; {inner}++)" + ) + + def set_ranges(self, group, reduction_group): + vars = super().set_ranges(group, reduction_group) + # do vertical reduction as the tail loop + self.outer_idx, self.tiling_idx = ( + self.tiling_indices + if self.tiling_indices[1] < self.reduction_depth + else reversed(self.tiling_indices) + ) + if self.tiling_idx == self.tiling_indices[0]: + self.tail_size = self.outer_tail_size + self.num_elems = self.outer_num_elems + self.inner_is_tiling_idx = False + else: + self.tail_size = self.inner_tail_size + self.num_elems = self.inner_num_elems + self.inner_is_tiling_idx = True + return vars + + def transform_indexing(self, index: sympy.Expr) -> sympy.Expr: + return self.scale_index_with_offset( + index, + itervar_idx=self.outer_idx, + offset=self.inner_itervar(), + ) + + +def get_loop_body_lowp_fp(_body: LoopBody) -> Tuple[Optional[torch.dtype], bool]: + """ + Returns the low precision data type (torch.float16/torch.bfloat16) contained in the nodes + and if all the nodes can codegen with this data type without converting to float. + Otherwise returns None and True. + """ + sub_blocks = [_body.root_block] + list(_body.subblocks.values()) + + _lowp_fp_type: Optional[torch.dtype] = None + _use_fp32 = False + for sub_block in sub_blocks: + for _node in sub_block.graph.nodes: + if _node.op == "placeholder" or _node.target in ( + "get_index", + "index_expr", + ): + continue + + # Fast path if all operations can support bf16/fp16 without converting to fp32 + if _node.target not in [ + "load", + "store", + "abs", + "neg", + "output", + ]: + _use_fp32 = True + + if hasattr(_node, "meta") and _node.meta: + assert OptimizationContext.key in _node.meta + opt_ctx: OptimizationContext = _node.meta[OptimizationContext.key] + if not opt_ctx.dtype or opt_ctx.dtype not in DTYPE_LOWP_FP: + _use_fp32 = True + elif _lowp_fp_type is not None: + if _lowp_fp_type != opt_ctx.dtype: + warnings.warn("bf16 and fp16 are mixed in the scheduler node.") + else: + _lowp_fp_type = opt_ctx.dtype + else: + _use_fp32 = True + + return _lowp_fp_type, _use_fp32 + + +class TilingSelect: + """ + Implement the heuristic to select the tiling factors and tiling indices. + In the future, we can implement advanced heuristic in a subclass. 
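    Editor's note (illustrative, assumed values): select_tiling returns two parallel lists,
    the tiling factors and the loop indices they apply to. A hypothetical call could look like:

        tiling_factors, tiling_indices = TilingSelect().select_tiling(fn_list, var_sizes_list)
        # e.g. ([16], [2]) -- one 16-wide tile on the innermost of three loops,
        # ([16, 16], [1, 2]) when a 2D tile is selected, or ([], []) when vectorization is disabled.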
+ """ + + def __init__(self): + super().__init__() + + def select_tiling( + self, + fn_list, + var_sizes_list, + ) -> Tuple[List[int], List[int]]: + # TODO(jgong5): support alternative tiling factors and data types + loop_bodies = _get_loop_body(fn_list) + all_dtypes = _get_dtype_from_loopbodies(loop_bodies) + assert all_dtypes + if any(dtype not in VECTORIZABLE_DTYPES for dtype in all_dtypes): + return [], [] + dtype = torch.float + _lowp_fp_dtype = get_loop_body_lowp_fp(loop_bodies[0])[0] + if _lowp_fp_dtype and all( + (get_loop_body_lowp_fp(loop_body)[0] == _lowp_fp_dtype) + for loop_body in loop_bodies[1:] + ): + dtype = _lowp_fp_dtype + + tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype) + tiling_indices = self._select_tiling_indices( + fn_list, var_sizes_list, tiling_factor + ) + + if tiling_indices: + group, reduction_group = max( + var_sizes_list, key=lambda sizes: len(sizes[1]) + ) + call_ranges = tuple(group) + tuple(reduction_group) + + if config.cpp.enable_tiling_heuristics: + + def _try_get_stride( + index, + itervars, + tiling_factor, + tiling_indices, + ): + itervar = itervars[tiling_indices[0]] + stride = stride_at_vec_range(index, itervar, tiling_factor) + return stride if stride.is_number else None + + def _update_negative_op_count( + node_name, non_contig_indexing_op_counter + ): + if node_name not in non_contig_indexing_op_counter: + non_contig_indexing_op_counter[node_name] = 1 + else: + non_contig_indexing_op_counter[node_name] += 1 + + def _is_valid_indices( + itervars, + tiling_indices, + ): + return ( + len(tiling_indices) == 1 + and len(itervars) > 0 + and ( + tiling_indices[0] + if tiling_indices[0] >= 0 + else tiling_indices[0] + len(itervars) + ) + < len(itervars) + ) + + itervars = [ + sympy_index_symbol_with_prefix(SymT.XBLOCK, n) + for n in range(len(call_ranges)) + ] + reduction_depth = len(group) + vars, reduction_vars = ( + itervars[:reduction_depth], + itervars[reduction_depth:], + ) + op_counter: Dict[str, int] = {} + # ops may cause overhead with vectorization, like non-contiguous + # index_expr, load, store + non_contig_indexing_op_counter: Dict[str, int] = {} + for _body in loop_bodies: + sub_blocks = [_body.root_block] + list(_body.subblocks.values()) + for sub_block in sub_blocks: + for _node in sub_block.graph.nodes: + if _node.target in ["index_expr", "load", "store"]: + # get the index and replace prefix from z to x + arg_idx = 1 if _node.target == "index_expr" else 2 + index = sub_block.body.indexing_from_args( + (vars, reduction_vars) + )[_node.args[arg_idx].args[0]] + if _is_valid_indices(itervars, tiling_indices): + stride = _try_get_stride( + index, itervars, tiling_factor, tiling_indices + ) + if ( + stride is None + if _node.target == "index_expr" + else stride not in [0, 1] + ): + _update_negative_op_count( + _node.target, non_contig_indexing_op_counter + ) + if isinstance(_node.target, str) and not ( + _node.target.startswith("masked_subblock") + or _node.target + in ["ops", "output", "constant", "get_index"] + ): + if _node.target not in op_counter: + op_counter[_node.target] = 1 + else: + op_counter[_node.target] += 1 + + op_num = sum(op_counter.values()) + non_contig_indexing_op_num = sum( + non_contig_indexing_op_counter.values() + ) + threshold = 0.08 + if op_num > 0 and non_contig_indexing_op_num / op_num >= threshold: + # Too many non-contiguous load/store/index_expr which hurts the + # vectorization performance. Disable vectorization when exceeding + # the threshold. 
+ return [], [] + + if ( + not reduction_group + and group + and len(tiling_indices) == 1 + and not has_free_symbols( + [ + group[tiling_indices[0]], + ] + ) + and group[tiling_indices[0]] < tiling_factor / 2 + ): + # For case of Multi Thread AMP Static shape of pyhpc_isoneutral_mixing, + # the inner loop range doesn't have enough elements to do vectorization + # explicitly and found that `#pragma GCC ivdep` has better performance than + # `#pragma omp simd simdlen(8)`. Disable vectorization for this case. + # Leslie: maybe we can always disable vectorization when loop range is less + # than tiling factor and enable `#pragma omp simd simdlen(8)` for scalar kernel + # when needed. + return [], [] + + if dtype in DTYPE_LOWP_FP: + # For lower precision data type, if the call_range is not long enough, + # use tiling_factor // 2 for better performance + factor_lowp = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype) + for tiling_indice in tiling_indices: + if tiling_indice < 0: + tiling_indice = tiling_indice + len(call_ranges) + if tiling_indice < 0 or tiling_indice >= len(call_ranges): + continue + if has_free_symbols(call_ranges): + call_range = V.graph.sizevars.size_hint( + call_ranges[tiling_indice], fallback=0 + ) + if call_range < factor_lowp: + V.graph.sizevars.guard_lt(call_range, factor_lowp) + tiling_factor = factor_lowp // 2 + break + elif call_ranges[tiling_indice] < factor_lowp: + tiling_factor = factor_lowp // 2 + break + + if len(tiling_indices) == 1: + return [tiling_factor], tiling_indices + if len(tiling_indices) == 2: + return [tiling_factor, tiling_factor], tiling_indices + return [], [] + + def _select_tiling_indices( + self, + fn_list, + var_sizes_list, + tiling_factor, + ): + all_index = [] + for fn, var_sizes in zip(fn_list, var_sizes_list): + rw = dependencies.extract_read_writes(fn, *var_sizes) + all_index += [dep.index for dep in itertools.chain(rw.reads, rw.writes)] + contig_vars = set() + contig_vars_list = [] + non_contig_stride_const = set() + non_contig_stride_other = set() + for index in all_index: + for var in index.free_symbols: + if not re.search(r"^d\d+$", var.name): + continue + stride = stride_at_vec_range(index, var, tiling_factor) + if stride == 0: + continue + elif stride == 1: + contig_vars.add(int(var.name[1:])) + contig_vars_list.append(int(var.name[1:])) + elif all(symbol_is_type(s, SymT.SIZE) for s in stride.free_symbols): + non_contig_stride_const.add(int(var.name[1:])) + else: + non_contig_stride_other.add(int(var.name[1:])) + contig_only = contig_vars - non_contig_stride_const - non_contig_stride_other + group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1])) + num_itervars = len(group) + len(reduction_group) + if len(contig_vars) == 0: + # no contiguous vars + return [num_itervars - 1] + if contig_only: + return sorted(contig_only)[-1:] + contig_and_const_stride = ( + contig_vars & non_contig_stride_const + ) - non_contig_stride_other + contig_vars_sorted = sorted(contig_vars) + if ( + len(contig_vars_sorted) == 2 + and contig_vars_sorted[-1] in contig_and_const_stride + and contig_vars_sorted[-1] == num_itervars - 1 + ): + return contig_vars_sorted + return sorted(contig_vars_sorted, key=contig_vars_list.count)[-1:] + + +class CppKernelProxy(CppKernel): + def __init__(self, kernel_group): + super().__init__(kernel_group.args, kernel_group.ws.num_threads) + self.kernel_group = kernel_group + self.loop_nest = None + self.call_ranges = None + self.picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa() + + def 
data_type_propagation(self, nodes): + for _node in nodes: + assert isinstance(_node, SchedulerNode) + DataTypePropagation.propagate_scheduler_node(_node) + + # Check if all the nodes of a given fx graph can support BF16/FP16 + def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode): + if not isinstance(scheduler_node._body, LoopBody): + return True + # Propagate the dtype to check if all the fx node is bf16/fp16 + DataTypePropagation.propagate_scheduler_node(scheduler_node) + return ( + get_loop_body_lowp_fp(scheduler_node._body)[0] is not None + and not get_loop_body_lowp_fp(scheduler_node._body)[1] + ) + + def legalize_lowp_fp_dtype_loopbody(self, loop_body: LoopBody): + def add_to_dtype(sub_graph: torch.fx.Graph): + def is_lowp_fp_load(node: torch.fx.Node): + if node.target not in ["load"]: + return False + assert len(node.args) == 3 + load_dtype = V.graph.get_dtype(node.args[1]) # type: ignore[arg-type] + return load_dtype in DTYPE_LOWP_FP + + def is_lowp_fp_store(node: torch.fx.Node): + if node.target != "store": + return False + _, store_var, _, _, _ = node.args + store_dtype = V.graph.get_dtype(store_var) # type: ignore[arg-type] + return store_dtype in DTYPE_LOWP_FP + + sub_graph_nodes = list(sub_graph.nodes) + to_lowp_fp_legalized_nodes = [] + for _node in sub_graph_nodes: + if is_lowp_fp_load(_node): + # No need to promote to float if all users are direct stores + if all(user.target == "store" for user in _node.users): + continue + ops = _node.args[0] + with sub_graph.inserting_after(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, _node, torch.float) + ) + to_type_node_args = to_type_node.args + _node.replace_all_uses_with(to_type_node) + to_type_node.args = to_type_node_args + metrics.cpp_to_dtype_count += 1 + elif is_lowp_fp_store(_node): + ops, name, _, value_var, _ = _node.args + # No need to promote to float if it is a user of a load which are all directly stored + if value_var.target == "load" and all( + user.target == "store" for user in value_var.users + ): + continue + dtype = V.graph.get_dtype(name) + with sub_graph.inserting_before(_node): + to_type_node = sub_graph.call_method( + "to_dtype", args=(ops, value_var, dtype) + ) + _node.replace_input_with(value_var, to_type_node) + metrics.cpp_to_dtype_count += 1 + elif _node.target == "reduction": + ( + ops, + dtype, + src_dtype, + reduction_type, + value, + ) = _node.args + if src_dtype in DTYPE_LOWP_FP: + # Since we always convert the load/store value to float if the tensor is bfloat16/float16. + # Therefore, the reduction should never work with bfloat16/float16 value. Hence, we update + # the bfloat16/float16 reduction by + # 1) updating the src_dtype to float + # and 2) updating the dtype to float if it is bfloat16/float16. + assert dtype in [ + torch.float, + torch.bfloat16, + torch.float16, + torch.int64, + ] + _node.args = ( + ops, + torch.float if dtype in DTYPE_LOWP_FP else dtype, + torch.float, + reduction_type, + value, + ) + elif _node.target == "to_dtype" and _node.args[-1] in DTYPE_LOWP_FP: + (ops, x, _) = _node.args + # The legalization always loads the BF16/FP16 tensor as FP32 for computation + # and converts back to BF16/FP16 after the computation. + # Hence, there should be no computation w/ BF16/FP16. + # Therefore, we update the to_dtype by replacing the bf16/fp16 dtype with fp32. 
+ # Save the legalized to_dtype node for the elimination(eliminate_to_dtype step): + # 1) Eliminate the redundant to_dtype node if we have a pattern as follows: + # graph(): + # %lowp_fp_legalized = call_method[target=to_dtype](args = (%ops, %input, torch.float)) + # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %lowp_fp_legalized, torch.bfloat16/float16)) + # Regarding the first to_dtype, it is redundant because + # the second to_type also converts to the torch.bfloat16/torch.float16. + # Hence, we remove the first to_type. + to_lowp_fp_legalized_nodes.append(_node) + _node.args = (ops, x, torch.float) + else: + pass + + def eliminate_to_dtype(sub_graph: torch.fx.Graph): + def _eliminate_duplicate_to_node(sub_graph: torch.fx.Graph): + # Eliminate the redundant to_dtype node. Let's consider a pattern as follows: + # graph(): + # %to_dtype1 = call_method[target=to_dtype](args = (%ops, %input, torch.float), kwargs = {}) + # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %to_dtype1, torch.float), kwargs = {}) + # Regarding the first to_dtype, it is redundant because the second to_type also converts to the + # torch.float. Hence, we remove the first to_type + def _used_by_to(to_node: torch.fx.Node): + return all(usr.target == "to_dtype" for usr in to_node.users) + + all_to_nodes = [ + node for node in sub_graph.nodes if node.target == "to_dtype" + ] + all_to_nodes_and_users = [ + {node: node.users} for node in all_to_nodes if _used_by_to(node) + ] + for node_users in all_to_nodes_and_users: + for node, users in node_users.items(): + if node in sub_graph.nodes and ( + all(usr.args[-1] == node.args[-1] for usr in users) + or ( + node in to_lowp_fp_legalized_nodes + and all( + usr.args[-1] in DTYPE_LOWP_FP for usr in users + ) + ) + ): + val_node = node.all_input_nodes[-1] + node.replace_all_uses_with(val_node) + sub_graph.erase_node(node) + + # For debug mode, the graph of LoopBody will attach a new GraphModule as + # owning_module for debugging while the release mode will not. The lint will + # check whether the graph has owning_module to decide if it needs to check + # call_module. LoopBody might contain get_index as a module call. But it + # is just a function. Hence, it cannot pass the lint check for debug mode. + # We bypass the check if the owning_module is None. Eventually, we should call + # get_index via call_function but not call_module. 
+ if sub_graph.owning_module is None: + sub_graph.lint() + + _eliminate_duplicate_to_node(sub_graph) + + eliminate_to_dtype(sub_graph) + + sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values()) + for sub_block in sub_blocks: + add_to_dtype(sub_block.graph) + + def legalize_lowp_fp_dtype(self, nodes): + if all( + isinstance(_node, SchedulerNode) and self.is_lowp_fp_scheduler(_node) + for _node in nodes + ): + # Mark the load node to load bf16/fp16 + for _node in nodes: + sub_blocks = [_node._body.root_block] + list( + _node._body.subblocks.values() + ) + for sub_block in sub_blocks: + for fx_node in sub_block.graph.nodes: + if fx_node.target in ["load", "store"]: + assert fx_node.meta + assert OptimizationContext.key in fx_node.meta + opt_ctx: OptimizationContext = fx_node.meta[ + OptimizationContext.key + ] + assert opt_ctx.dtype in DTYPE_LOWP_FP + + # Bypass the legalization as the kernel can run with bf16/fp16 directly + return + + for _node in nodes: + assert isinstance(_node, SchedulerNode) + assert isinstance(_node._body, LoopBody) + body: LoopBody = _node._body + if not body.is_memory_copy(): + self.legalize_lowp_fp_dtype_loopbody(body) + + def codegen_functions(self, fn_list, var_sizes_list): + assert len(fn_list) == len(var_sizes_list) + kernel_group = self.kernel_group + group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1])) + + self.set_ranges(group, reduction_group) + + def codegen_kernel(cls, *args): + with kernel_group.new_kernel(cls, *args) as kernel: + # Ugly hack to maintain the metrics kernel count since + # we only count in CppKernelProxy, not those contained in it + metrics.generated_kernel_count -= 1 + + run(kernel) + return kernel + + def run(kernel): + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + in_suffix = False + for fn, var_sizes in zip(fn_list, var_sizes_list): + if var_sizes in [ + (group, reduction_group), + (tuple(itertools.chain(group, reduction_group)), ()), + ]: + assert not in_suffix + fn(vars, reduction_vars) + else: + in_suffix = True + assert var_sizes == ( + group, + (), + ), f"unexpected group: {var_sizes} != {group}, {reduction_group}" + # we can fuse in some extra pointwise into the suffix + with kernel.write_to_suffix(): + fn(vars, ()) + + scalar_kernel = codegen_kernel(CppKernel) + V.graph.removed_buffers |= scalar_kernel.removed_buffers + V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove + self.loop_nest = LoopNestWithSplit.build(scalar_kernel) + + if not self.picked_vec_isa: + return + + if not self.itervars: + # not a loop + return + + # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. + # But the generated scalar kernel has updated these global contexts. Hence, the other kernels + # should not do this again to avoid context conflict. By now, we only control the + # config.inplace_buffers. In the future, we could maintain more contexts. + with torch._inductor.config.patch(inplace_buffers=False): + tiling_select = TilingSelect() + tiling_factors, tiling_indices = tiling_select.select_tiling( + fn_list, var_sizes_list + ) + assert len(tiling_factors) == len(tiling_indices) + # This should be removed after full support for vectorization is implemented. 
+ could_masked_vec = True + all_dtypes = _get_dtype_from_loopbodies(_get_loop_body(fn_list)) + if any(dtype not in MASKED_VECTORIZABLE_DTYPES for dtype in all_dtypes): + # can be removed after masked vectorizable dtype are same with vectorizable dtype + could_masked_vec = False + + if len(tiling_indices) == 1: + vec_kernel = codegen_kernel( + CppVecKernel, tiling_factors[0], tiling_indices[0] + ) + metrics.generated_cpp_vec_kernel_count += 1 + main_loop, tail_loop = self.loop_nest.split_with_tiling( + tiling_indices[0], factor=tiling_factors[0] + ) + main_loop.set_kernel(vec_kernel) + main_loop.simd_vec = True + if config.cpp.enable_loop_tail_vec and could_masked_vec: + tail_loop.steps = tail_loop.size - tail_loop.offset + masked_vec_kernel = codegen_kernel( + CppVecKernel, + tiling_factors[0], + tiling_indices[0], + tail_loop.steps, + ) + tail_loop.set_kernel(masked_vec_kernel) + tail_loop.simd_vec = True + else: + tail_loop.set_kernel(scalar_kernel) + tail_loop.simd_omp = True + # We chop the loop into two cubes by the nelements - main loop and tail loop. + # Regarding the main loop, it is straightforward that it could be vectorized with + # nelements. But for the tail loop, it still could be vectorized. For example, + # if the nelements is 8(256bits), then the tail loop still could be vectorized + # as 4(128bits). + tail_loop.simd_nelements = tiling_factors[0] // 2 + elif len(tiling_indices) == 2: + assert ( + tiling_indices[1] == len(self.itervars) - 1 + and tiling_factors[0] == tiling_factors[1] + ) + + metrics.generated_cpp_vec_kernel_count += 2 + outer_main_loop, outer_tail_loop = self.loop_nest.split_with_tiling( + tiling_indices[0], factor=tiling_factors[0] + ) + ( + inner_main_loop, + inner_tail_loop, + ) = outer_main_loop.split_with_tiling( + tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0] + ) + tile2d_kernel = codegen_kernel( + CppTile2DKernel, tiling_factors[0], tiling_indices + ) + inner_main_loop.set_kernel(tile2d_kernel) + + if config.cpp.enable_loop_tail_vec and could_masked_vec: + ( + inner_main_loop_of_outer_tail_loop, + inner_tail_loop_of_outer_tail_loop, + ) = outer_tail_loop.split_with_tiling( + tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0] + ) + + for tail_loop in ( + inner_tail_loop, + outer_tail_loop, + inner_tail_loop_of_outer_tail_loop, + ): + tail_loop.steps = tail_loop.size - tail_loop.offset + + for tail_loop, inner_tail_size, outer_tail_size in ( + (inner_tail_loop, inner_tail_loop.steps, None), + ( + inner_main_loop_of_outer_tail_loop, + None, + outer_tail_loop.steps, + ), + ( + inner_tail_loop_of_outer_tail_loop, + inner_tail_loop_of_outer_tail_loop.steps, + outer_tail_loop.steps, + ), + ): + masked_tile2d_kernel = codegen_kernel( + CppTile2DKernel, + tiling_factors[0], + tiling_indices, + inner_tail_size, + outer_tail_size, + ) + tail_loop.set_kernel(masked_tile2d_kernel) + else: + vec_kernel = codegen_kernel( + CppVecKernel, tiling_factors[0], tiling_indices[0] + ) + inner_tail_loop.set_kernel(vec_kernel) + + outer_tail_loop.set_kernel(scalar_kernel) + + def codegen_loop_bodies(self, loop_bodies, var_sizes_list): + for body in loop_bodies: + self.legalize_lowp_fp_dtype_loopbody(body) + DataTypePropagation.propagate_loopbody(body) + self.codegen_functions(loop_bodies, var_sizes_list) + + def codegen_nodes(self, nodes: List[SchedulerNode]): + # Legalize BF16 node by adding to_dtype explicitly + self.legalize_lowp_fp_dtype(nodes) + self.data_type_propagation(nodes) + assert len(nodes) >= 1 + + def fn(node, *index_vars): + 
node.decide_inplace_update() + node.mark_run() + if isinstance(V.kernel, NullKernelHandler): + return node._body(*index_vars) + else: + return node.codegen(index_vars) + + fn_list = [functools.partial(fn, node) for node in nodes] + + if ( + isinstance(V.local_buffer_context, LocalBufferContext) + and V.local_buffer_context.local_buffers + ): + + def wrap_fn(fn): + wrapped_fn = V.local_buffer_context.localize_function( + fn, + ) + wrapped_fn.original_fn = fn + return wrapped_fn + + fn_list = [wrap_fn(fn) for fn in fn_list] + + var_sizes_list = [node.group[1] for node in nodes] + self.codegen_functions(fn_list, var_sizes_list) + + def codegen_loops(self, code, worksharing): + self.codegen_loops_impl(self.loop_nest, code, worksharing) + + +class OuterLoopFusedKernel(CppKernel): + def __init__(self, kernel_group): + super().__init__(kernel_group.args, kernel_group.ws.num_threads) + self.inner: List[LoopLevel] = [] + + def decide_parallel_depth(self, max_parallel_depth, threads) -> int: + kernels_parallel_depth = [] + nested_kernels: List[List[CppKernel]] = [ + loop.get_kernels() for loop in self.inner + ] + for kernels in nested_kernels: + # For any ScalarKernel, VecKernel, or Tile2DKernel, + # they should all have the same call_ranges + call_ranges = kernels[0].call_ranges + assert call_ranges is not None + assert all(kernel.call_ranges == call_ranges for kernel in kernels) + kernels_parallel_depth.append( + kernels[0].decide_parallel_depth(len(call_ranges), threads) + ) + return min( + max_parallel_depth, + max(kernels_parallel_depth), + ) + + +class ReasonFusedNodes(Enum): + SAME_VARS_REDUCE = "same_vars_reduce" + COMPATIBLE_REDUCTION = "compatible_reduction" + COMPATIBLE_RANGES_NO_REDUCTION = "compatible_ranges_no_reduction" + + +class CppScheduling(BaseScheduling): + # ctypes limits the number of args to 1024, refer to: + # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237 + # We set a conservative threshold here. 
+ MAX_FUSED_KERNEL_ARGS_NUM = 500 + backend_features = dict.fromkeys( + [ + BackendFeature.INPLACE_BUFFERS, + BackendFeature.REDUCE_TO_SINGLE_ELEMENT, + ] + ) + + @classmethod + def get_backend_features(cls, device: torch.device): + return cls.backend_features + + def __init__(self, scheduler): + super().__init__() + self.scheduler = scheduler + if scheduler: + self.reset_kernel_group() + self._ready_to_flush = False + + def _set_flush_status(self, status: bool): + self._ready_to_flush = status + + def group_fn(self, sizes): + return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) + + def reset_kernel_group(self): + from .cpp_wrapper_cpu import CppWrapperCpu + + self.kernel_group: Union[CppWrapperKernelGroup, KernelGroup] + if isinstance(V.graph.wrapper_code, CppWrapperCpu): + self.kernel_group = CppWrapperKernelGroup() + else: + self.kernel_group = KernelGroup() + + def fuse(self, node1, node2): + if node1.is_foreach() or node2.is_foreach(): + return ForeachKernelSchedulerNode.fuse(node1, node2) + elif node1.is_template(): + assert not node2.is_template() + return FusedSchedulerNode.fuse(node1, node2) + else: + if ( + self._why_fuse_nodes(node1, node2) + == ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION + ): + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) + assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) + + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + assert reduce1 == () and reduce2 == (), (reduce1, reduce2) + + def get_indexing_ranges_exprs(node): + if isinstance(node, FusedSchedulerNode): + assert len(node.snodes) > 0, node.snodes + var_ranges = None + indexing_exprs = set() + for snode in node.snodes: + v, exprs = get_indexing_ranges_exprs(snode) + if var_ranges is None: + var_ranges = v + assert var_ranges == v, (var_ranges, v, node.snodes) + indexing_exprs.update(exprs) + return var_ranges, list(indexing_exprs) + else: + assert isinstance(node, SchedulerNode) + comp_buffer = node.node + assert isinstance(comp_buffer, ir.ComputedBuffer) + _, body, _ = comp_buffer.get_default_sizes_body() + return body.var_ranges, list(body.indexing_exprs.values()) + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + assert isinstance(node_to_recomp, SchedulerNode) + + ref_node = node2 if len(vars1) < len(vars2) else node1 + + extra_indexing_constraints = get_indexing_ranges_exprs(ref_node) + + node_to_recomp.recompute_size_and_body( + extra_indexing_constraints=extra_indexing_constraints + ) + + _, (vars1, _) = node1.group + _, (vars2, _) = node2.group + assert vars1 == vars2, (vars1, vars2) + return FusedSchedulerNode.fuse(node1, node2) + elif self.can_fuse_vertical_outer_loop(node1, node2): + return OuterLoopFusedSchedulerNode.fuse( + node1, node2, self._get_outer_loop_fusion_depth(node1, node2) + ) + else: + return FusedSchedulerNode.fuse(node1, node2) + + def _why_fuse_nodes(self, node1, node2) -> Optional[ReasonFusedNodes]: + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + if vars1 == vars2 and reduce1 == reduce2: + return ReasonFusedNodes.SAME_VARS_REDUCE + if reduce1 == () and vars1 == vars2 + reduce2: + return ReasonFusedNodes.COMPATIBLE_REDUCTION + if self._can_fuse_nodes_with_compatible_ranges(node1, node2): + return ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION + # TODO(jansel): allow fusion pointwise (vars1, ()) suffix? 
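+        # (Editor's note: illustrative example, not from the upstream source. The
+        # COMPATIBLE_REDUCTION branch above matches e.g. a pointwise node with group
+        # vars1=(s0, s1), reduce1=() against a reduction node with vars2=(s0,),
+        # reduce2=(s1,), since reduce1 == () and vars1 == vars2 + reduce2.)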
+ return None + + def _can_fuse_nodes_with_compatible_ranges(self, node1, node2): + # Here we try to fuse SchedulerNode/FusedSchedulerNode with compatible ranges + # e.g. (s0, s1, s2) and (s0 * s1 * s2) + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + + c1 = reduce1 == () and reduce2 == () + c2 = math.prod(vars1) == math.prod(vars2) + c3 = len(vars1) == 1 or len(vars2) == 1 + if not (c1 and c2 and c3): + return False + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + ref_node = node2 if len(vars1) < len(vars2) else node1 + + # We can not recompute sizes and body for nodes other than SchedulerNode + # TODO: we can extend fusion support with compatible ranges for FusedSchedulerNode + if isinstance(node_to_recomp, FusedSchedulerNode): + return False + + # It may happen that node1 and node2 compatible number of elements + # but different original ranges, for example: + # {d0: s0, d1: s1, d2: s2} vs {d0: s0*s1*s2} + # See https://github.com/pytorch/pytorch/pull/120077/files#r1500427848 for more details + # TODO: we can fix if it allows us to CSE at least one of the variables + + assert isinstance(node_to_recomp, SchedulerNode) + if isinstance(node_to_recomp.node, ir.TemplateBuffer): + return False + assert isinstance(node_to_recomp.node, ir.ComputedBuffer) + # node.data.get_size() is a cheaper version of node.get_read_writes().var_ranges + # but without variable name + ranges2 = node_to_recomp.node.data.get_size() + ranges1 = None + if isinstance(ref_node, FusedSchedulerNode): + ranges_set = set() + for snode in ref_node.snodes: + if isinstance(snode.node, ir.TemplateBuffer): + break + assert isinstance(snode.node, ir.ComputedBuffer) + ranges_set.add(tuple(snode.node.data.get_size())) + + if len(ranges_set) != 1: + return False + + ranges1 = list(next(iter(ranges_set))) + else: + assert isinstance(ref_node, SchedulerNode) + assert isinstance(ref_node.node, ir.ComputedBuffer) + ranges1 = ref_node.node.data.get_size() + + if ranges1 != ranges2: + return False + + return True + + def _can_fuse_horizontal_impl(self, node1, node2): + assert isinstance(node1, (FusedSchedulerNode, SchedulerNode)) + assert isinstance(node2, (FusedSchedulerNode, SchedulerNode)) + if any( + isinstance(node, OuterLoopFusedSchedulerNode) for node in (node1, node2) + ): + return False + return self._why_fuse_nodes(node1, node2) is not None + + def can_fuse_horizontal(self, node1, node2): + if node1.is_template() or node2.is_template(): + return False + if ( + len(node1.get_nodes()) + len(node2.get_nodes()) + > config.cpp.max_horizontal_fusion_size + ): + return False + + return self._can_fuse_horizontal_impl(node1, node2) + + def _get_outer_loop_fusion_depth(self, node1, node2): + DISABLE_OUTER_LOOP_FUSION = 0 + if not all( + type(node) + in (OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode) + for node in (node1, node2) + ): + return DISABLE_OUTER_LOOP_FUSION + + _node1 = ( + node1.get_outer_nodes()[-1] + if isinstance(node1, OuterLoopFusedSchedulerNode) + else node1 + ) + assert isinstance(_node1, (FusedSchedulerNode, SchedulerNode)) + _node2 = ( + node2.get_outer_nodes()[0] + if isinstance(node2, OuterLoopFusedSchedulerNode) + else node2 + ) + assert isinstance(_node2, (FusedSchedulerNode, SchedulerNode)) + + _, (vars1, reduce1) = _node1.group + _, (vars2, reduce2) = _node2.group + if vars1 == () and vars2 == () and reduce1 != () and reduce2 != (): + # Reduction only + return DISABLE_OUTER_LOOP_FUSION + if all(type(node) is OuterLoopFusedSchedulerNode for node in 
(node1, node2)): + return ( + node1.outer_loop_fusion_depth + if node1.outer_loop_fusion_depth == node2.outer_loop_fusion_depth + else DISABLE_OUTER_LOOP_FUSION + ) + outer_loop_fusion_depth = min(len(vars1), len(vars2)) + if ( + outer_loop_fusion_depth >= 1 + and vars1[:outer_loop_fusion_depth] == vars2[:outer_loop_fusion_depth] + ): + if any( + type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2) + ): + _compare_node = ( + node1 if type(node1) is OuterLoopFusedSchedulerNode else node2 + ) + if _compare_node.outer_loop_fusion_depth == outer_loop_fusion_depth: + # Same outer loop fusion depth as prev nodes in OuterLoopFusedSchedulerNode + return outer_loop_fusion_depth + else: + return DISABLE_OUTER_LOOP_FUSION + else: + # First 2 nodes to generate OuterLoopFusedSchedulerNode + return outer_loop_fusion_depth + return DISABLE_OUTER_LOOP_FUSION + + def can_fuse_vertical_outer_loop(self, node1, node2): + return ( + not node1.is_template() + and not node2.is_template() + and node1.get_operation_names() & node2.ancestors + and not ( + self._can_fuse_horizontal_impl(node1, node2) + and not node1.is_reduction() + ) + and self._get_outer_loop_fusion_depth(node1, node2) >= 1 + ) + + def get_fusion_pair_priority(self, node1, node2): + if self.can_fuse_vertical_outer_loop(node1, node2): + # Outer loop fusion with lower priority + return 1 + else: + return 0 + + def can_fuse_vertical(self, node1, node2): + if node2.is_template(): + # TODO(jgong5): support pre-op fusion with template + return False + if node1.is_template(): + return not node2.is_reduction() + return ( + self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() + ) or self.can_fuse_vertical_outer_loop(node1, node2) + + def try_loop_split(self, nodes: List[SchedulerNode]): + """ + Apply loop split optimization. + When one of the indexing_exprs contains a division, we eliminate the division by splitting the loop + to avoid non-contiguous loads, subject to the following conditions: + 1. No reduction and no mudular index for all nodes. + 2. Only one node's one indexing_exprs contains a division, according to this indexing_exprs, + we can get the dimension that needs to be split, and the split dimension is contiguous + in all other indexing_exprs. + + For example, if the node's var_ranges: {z0: 2, z1: 9216, z2: 960} and indexing_exprs: + {'index0': 8847360*z0 + 960*z1 + z2, 'index1': 32*z0 + (z2//30), 'index2': z2}, + we will split z2 -> 30*z2 + z3, then the node's var_ranges will be changed to + {z0: 2, z1: 9216, z2: 32, z3: 30} and indexing_exprs will be changed to + {'index0': 8847360*z0 + 960*z1 + 30*z2 + z3, 'index1': 32*z0 + z2, 'index2': 30*z2 + z3}. 
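+
+        (Editor's note: the check below is an illustration and is not part of the upstream
+        docstring.) The substitution can be sanity-checked with plain integer arithmetic,
+        since divmod recovers the split coordinates:
+
+            for old_z2 in range(960):
+                new_z2, z3 = divmod(old_z2, 30)
+                assert old_z2 // 30 == new_z2       # 'index1' loses its division
+                assert 30 * new_z2 + z3 == old_z2   # 'index0'/'index2' keep their values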
+ """ + + # No reduction and no mudular + if any( + len(node.group[1][1]) != 0 + or any( + expr.has(ModularIndexing) for expr in node._body.indexing_exprs.values() + ) + for node in nodes + ): + return nodes + + split_var = None + split_number = None + divide_index_name = None + num_div = 0 + match_div = False + matched_node = None + + for node in nodes: + assert isinstance(node.node, ir.ComputedBuffer) + _, original_body, _ = node.node.get_default_sizes_body() + for name, expr in original_body.indexing_exprs.items(): + num_div += expr.count(FloorDiv) + if num_div > 1: + return nodes + if expr.count(FloorDiv) == 1: + div_expr = expr.find(FloorDiv).pop() + split_var = div_expr.args[0] + split_number = div_expr.args[1] + divide_index_name = name + if ( + isinstance(split_number, sympy.core.numbers.Integer) + and isinstance(split_var, sympy.core.symbol.Symbol) + and split_var in original_body.iter_vars + and divide_index_name is not None + and all( + stride_at_vec_range(expr, split_var) == 1 + for name, expr in original_body.indexing_exprs.items() + if name != divide_index_name + ) + ): + match_div = True + matched_node = node + + # Only one node contains a division, and the split dimension is contiguous in all other indexing_exprs. + if not match_div: + return nodes + + extra_indexing_constraints = None + + def loop_split(sizes, body, vars): + index_size, reduce_size = sizes + index_vars, reduce_vars = vars + split_idx = index_vars.index(split_var) + new_index_size = index_size.copy() + new_index_size[split_idx] = index_size[split_idx] // split_number + new_index_size.insert(split_idx + 1, split_number) + (new_index_vars, _), var_ranges = dependencies.index_vars_no_squeeze( + new_index_size, reduce_size, prefix="y" + ) + iter_vars = new_index_vars.copy() + divisor_var = iter_vars.pop(split_idx + 1) + iter_vars[split_idx] = split_number * iter_vars[split_idx] + divisor_var + body = ir.LoopBody( + body, [iter_vars, reduce_vars], var_ranges, new_index_vars, reduce_vars + ) + nonlocal extra_indexing_constraints + if not extra_indexing_constraints: + extra_indexing_constraints = ( + body.var_ranges, + list(body.indexing_exprs.values()), + ) + return ( + (new_index_size, reduce_size), + body, + (new_index_vars, reduce_vars), + ) + + # Here decide the final loop order + for node in nodes: + if node == matched_node: + node.recompute_size_and_body(recompute_sizes_body_func=loop_split) + for node in nodes: + if node != matched_node: + node.recompute_size_and_body( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=loop_split, + ) + + return nodes + + def codegen_outer_loop_node( + self, + node: OuterLoopFusedSchedulerNode, + ): + """ + Generate the code for the outer loop fused scheduler node. + 1. Codegen with fused outer loop: depends on the analysis of + the outer loop fused scheduler node, with or without the local buffer. + 2. If failed, fallback to standard codegen. + """ + kernel_group = self.kernel_group + generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count + cpp_kernel_proxy_list: List[CppKernelProxy] = [] + nodes_list: List[List[SchedulerNode]] = [] + assert isinstance(node, OuterLoopFusedSchedulerNode) + + def try_outer_loop_fusion_with_local_buf(node: OuterLoopFusedSchedulerNode): + """ + Codegen code with fused outer loop and local Buffer. 
+ """ + assert isinstance(node, OuterLoopFusedSchedulerNode) + cpp_kernel_proxy_list.clear() + nodes_list.clear() + + def get_call_ranges(node: BaseSchedulerNode): + assert isinstance(node, (SchedulerNode, FusedSchedulerNode)) + nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment] + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + call_ranges = tuple(group) + tuple(reduction_group) + return call_ranges + + local_buffers: List[ir.Buffer] = [] + # Map local buffer name to a list of global buffers + local_to_global_buffers: Dict[str, List[ir.Buffer]] = {} + if all( + len(get_call_ranges(_node)) == node.outer_loop_fusion_depth + 1 + for _node in node.get_outer_nodes() + ): + # Ref to the typical case of local buffer + # in https://github.com/pytorch/pytorch/blob/ + # 1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 + # where the buffer is with size of last dim and contiguous. + # Only support this typical case at first. + visited_scheduler_nodes: Set[str] = set() + for scheduler_node in node.get_nodes(): + # all users inside same OuterLoopFusedSchedulerNode + assert isinstance(scheduler_node, SchedulerNode) + visited_scheduler_nodes.add(scheduler_node.get_name()) + if ( + scheduler_node.is_reduction() + or len(scheduler_node.get_outputs()) != 1 + ): + continue + + scheduler_buffer = scheduler_node.get_outputs()[0] + if all( + user.node in node.get_nodes() for user in scheduler_buffer.users + ): + global_buffer = scheduler_buffer.node + assert isinstance(global_buffer, ir.ComputedBuffer) + global_buffer_layout = global_buffer.get_layout() + size_offset = node.outer_loop_fusion_depth - len( + get_call_ranges(scheduler_node) + ) + + def is_all_write_read_contiguous(): + contiguous_index_expr = 0 + stride = 1 + for var, range in reversed( + scheduler_node._body.var_ranges.items() + ): + contiguous_index_expr += stride * var + stride *= range + write_index_expr = scheduler_node._body.get_write_expr( + scheduler_buffer.get_name() + ) + + def is_contiguous_index(x): + return x == contiguous_index_expr + + return is_contiguous_index(write_index_expr) and all( + isinstance(user.node, SchedulerNode) + and is_contiguous_index( + user.node._body.get_read_expr( + scheduler_buffer.get_name() + ), + ) + for user in scheduler_buffer.users + ) + + if not ( + global_buffer_layout.is_contiguous() + and is_all_write_read_contiguous() + ): + continue + # Local Buffer is a view of global buffer + local_buffer_layout = ir.FixedLayout( + global_buffer_layout.device, + global_buffer_layout.dtype, + global_buffer_layout.size[size_offset:], + global_buffer_layout.stride[size_offset:], + ) + + def try_share_local_buffer(local_buffer_layout, local_buffers): + for local_buf in local_buffers: + if local_buffer_layout == local_buf.layout and all( + all( + user.node.get_name() in visited_scheduler_nodes + for user in V.graph.scheduler.name_to_buf[ + global_buffer.name + ].users + ) + for global_buffer in local_to_global_buffers[ + local_buf.name + ] + if global_buffer.name is not None + ): + return local_buf + return None + + local_buf_prefix = "local_buffer_data" + # Share existing local buffer + local_buffer_used = try_share_local_buffer( + local_buffer_layout, local_buffers + ) + if not local_buffer_used: + # Create new local buffer + local_buffer_used = ir.Buffer( + f"{local_buf_prefix}_{len(local_buffers)}", + local_buffer_layout, + ) + local_buffers.append(local_buffer_used) + 
local_to_global_buffers[local_buffer_used.name] = [] + local_to_global_buffers[local_buffer_used.name].append( + global_buffer, + ) + + with LocalBufferContext(kernel_group.args) as scope: + if len(local_buffers) > 0: + for local_buffer in local_buffers: + assert local_buffer.name is not None + scope.add_local_buffer( + local_buffer, local_to_global_buffers[local_buffer.name] + ) + for _node in node.get_outer_nodes(): + assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) + cpp_kernel_proxy = CppKernelProxy(kernel_group) + cpp_kernel_proxy.codegen_nodes(_node.get_nodes()) # type: ignore[arg-type] + cpp_kernel_proxy_list.append(cpp_kernel_proxy) + nodes_list.append(_node.get_nodes()) # type: ignore[arg-type] + + if not node.check_outer_fusion_loop_level_attr( + cpp_kernel_proxy_list, node.outer_loop_fusion_depth + ): + return False + metrics.cpp_outer_loop_fused_inner_counts.append( + metrics.CppOuterLoopFusedCount( + len(cpp_kernel_proxy_list), + local_buffer_number=len(scope.local_buffers), + ) + ) + outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels( + cpp_kernel_proxy_list, + ) + kernel_group.finalize_kernel( + outer_fusion_cpp_kernel_proxy, + [_node for _nodes in nodes_list for _node in _nodes], + ) + + return True + + if not try_outer_loop_fusion_with_local_buf(node): + # Reset generated_cpp_vec_kernel_count to codegen again + metrics.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count + cpp_kernel_proxy_list.clear() + nodes_list.clear() + # Similar as comment in + # https://github.com/pytorch/pytorch/blob/469383755fe416eb1c41fa724762ad3eaecdff07/torch/_inductor/codegen/cpp.py#L3269-L3272 + # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. + with torch._inductor.config.patch(inplace_buffers=False): + for _node in node.get_outer_nodes(): + assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) + _nodes: List[SchedulerNode] = _node.get_nodes() # type: ignore[assignment] + cpp_kernel_proxy = CppKernelProxy(kernel_group) + cpp_kernel_proxy.codegen_nodes(_nodes) + kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes) + + def codegen_node( + self, + node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode], + ): + """ + Turn an set of pre-fused nodes into a C++ kernel. 
+ """ + kernel_group = self.kernel_group + + if isinstance(node, OuterLoopFusedSchedulerNode): + self.codegen_outer_loop_node(node) + else: + nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment] + nodes = self.try_loop_split(nodes) + cpp_kernel_proxy = CppKernelProxy(kernel_group) + cpp_kernel_proxy.codegen_nodes(nodes) + kernel_group.finalize_kernel(cpp_kernel_proxy, nodes) + + args_num = self._get_scheduled_num_args() + if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM: + self._set_flush_status(True) + + def is_cpp_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, ir.CppTemplateBuffer + ) + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a CPP template, possibly with fused epilogues + """ + counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cpp_template( + template_node + ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" + template_node = cast(SchedulerNode, template_node) + _, (_, rnumel) = template_node.group + assert rnumel == () + ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node) + epilogue_ir_nodes: List[Optional[ir.Operation]] = [ + n.node for n in epilogue_nodes + ] + assert all( + isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes + ), "Epilogue nodes must all be instances of ir.ComputedBuffer" + + def template_buffer_has_other_users( + template_buffer, outputs_by_name, epilogue_nodes + ): + assert template_buffer.get_name() in outputs_by_name + users = outputs_by_name[template_buffer.get_name()].users + return not all( + isinstance(user.node, BaseSchedulerNode) + and user.node.node in epilogue_nodes + for user in users + ) + + flag_template_buffer_has_other_users = template_buffer_has_other_users( + ctb, template_node.outputs_by_name, epilogue_ir_nodes + ) + kernel, render = ctb.make_kernel_render( + ctb, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_ir_nodes, + ) + with kernel: + for node in [template_node, *epilogue_nodes]: + node.mark_run() # type: ignore[attr-defined] + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node, *epilogue_nodes] + kernel_name = self.define_kernel(src_code, node_schedule, kernel.args) + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.scheduler.free_buffers() + + def _get_scheduled_num_args(self): + return self.kernel_group.get_num_args() + + def ready_to_flush(self): + return self._ready_to_flush + + def codegen_sync(self): + pass + + def define_kernel(self, src_code, nodes, kernel_args=None): + wrapper = V.graph.wrapper_code + fused_name = ( + get_fused_kernel_name(nodes, config.cpp.descriptive_names) + if config.cpp.descriptive_names + else "" + ) + kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]) + kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name) + src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. 
+        src_code = src_code.replace("#pragma CMT", "//")
+
+        compile_wrapper = IndentedBuffer()
+        args = self.kernel_group.args if kernel_args is None else kernel_args
+        _, _, arg_types = args.cpp_argdefs()
+        if not V.graph.cpp_wrapper:
+            compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''")
+        compile_wrapper.splice(src_code, strip=True)
+        if not V.graph.cpp_wrapper:
+            compile_wrapper.writeline("''')")
+        wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), cuda=False)
+        return kernel_name
+
+    def flush(self):
+        src_code = self.kernel_group.codegen_group()
+        if src_code:
+            kernel_name = self.define_kernel(
+                src_code, self.kernel_group.scheduled_nodes
+            )
+            self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name)
+        self.reset_kernel_group()
+        self._set_flush_status(False)
+
+
+class KernelGroup:
+    def __init__(self):
+        super().__init__()
+        self.args = KernelArgs()
+        self.loops_code = BracesBuffer()
+        self.ws = WorkSharing(self.loops_code)
+        self.stack = contextlib.ExitStack()
+        self.stack.enter_context(self.ws)
+        self.scheduled_nodes = []
+
+    def new_kernel(self, cls, *args):
+        return cls(self.args, parallel_num_threads(), *args)
+
+    def finalize_kernel(self, new_kernel, nodes):
+        self.scheduled_nodes += nodes
+        code = self.loops_code
+        ws = self.ws
+        new_kernel.codegen_loops(code, ws)
+
+    def get_num_args(self):
+        arg_defs, call_args, arg_types = self.args.cpp_argdefs()
+        args_num = len(arg_defs)
+        return args_num
+
+    def codegen_group(self, name=None) -> str:
+        self.stack.close()
+        if not self.scheduled_nodes:
+            return ""
+        code = BracesBuffer()
+        # 1. Include header files
+        # TODO: support kernel profile on other platforms
+        enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [
+            "linux",
+            "win32",
+        ]
+        if enable_kernel_profile:
+            code.writelines(["#include <ATen/record_function.h>"])
+        code.writeline(codecache.cpp_prefix())
+
+        # 2. Function definition
+        kernel_decl_name = str(Placeholder.KERNEL_NAME) if name is None else name
+        kernel_name = str(Placeholder.DESCRIPTIVE_NAME) if name is None else name
+        arg_defs, _, _ = self.args.cpp_argdefs()
+        arg_defs = ",\n".ljust(25).join(arg_defs)
+        func_export_decl = get_export_declaration()
+        code.writeline(
+            f'extern "C" {func_export_decl} void {kernel_decl_name}({arg_defs})'
+        )
+
+        # 3. Function body
+        with code.indent():
+            if enable_kernel_profile:
+                graph_id = V.graph.graph_id
+                prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
+                code.writelines(
+                    [
+                        f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef<c10::IValue>({{}}));'
+                    ]
+                )
+            for old, new in self.args.aliases():
+                code.writeline(f"auto {old} = {new};")
+            code.splice(self.loops_code)
+        return code.getvalue()
+
+    def call_kernel(self, wrapper, kernel_name):
+        _, call_args, arg_types = self.args.cpp_argdefs()
+        wrapper.generate_kernel_call(
+            kernel_name, call_args, cuda=False, arg_types=arg_types
+        )
+
+
+class CppWrapperKernelGroup(KernelGroup):
+    def __init__(self):
+        super().__init__()
+        self.args = CppWrapperKernelArgs()
+
+
+class WorkSharing:
+    def __init__(self, code):
+        self.code = code
+        self.in_parallel = False
+        self.num_threads = None
+        self.stack = contextlib.ExitStack()
+
+    def parallel(self, threads):
+        if self.in_parallel and threads != self.num_threads:
+            # wrong number of threads
+            self.close()
+        if not self.in_parallel:
+            self.num_threads = threads
+            self.in_parallel = True
+            if config.cpp.dynamic_threads:
+                self.code.writeline("#pragma omp parallel")
+            else:
+                self.code.writeline(f"#pragma omp parallel num_threads({threads})")
+            self.stack.enter_context(self.code.indent())
+            self.code.writeline(
+                "int tid = omp_get_thread_num();",
+            )
+
+    def single(self):
+        if self.in_parallel:
+            self.code.writeline("#pragma omp single")
+        return self.in_parallel
+
+    def close(self):
+        self.stack.close()
+        self.in_parallel = False
+
+    def __enter__(self):
+        self.stack.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stack.__exit__(exc_type, exc_val, exc_tb)
+
+
+@dataclasses.dataclass
+class LoopLevel:
+    var: Optional[sympy.Expr] = None
+    size: Optional[sympy.Expr] = None
+    offset: sympy.Expr = sympy.Integer(0)
+    steps: sympy.Expr = sympy.Integer(1)
+    parallel: int = 0
+    simd_omp: bool = False
+    simd_vec: bool = False
+    collapsed: bool = False
+    is_reduction: bool = False
+    parent: Optional["LoopLevel"] = None
+    # the next inner level of the loop, empty if it is inner-most
+    # contains >1 LoopLevel if the inner level of loop is split
+    inner: List["LoopLevel"] = dataclasses.field(default_factory=list)
+    # kernel assigned to this loop level, only valid when it is a leaf
+    kernel: Optional[CppKernel] = None
+
+    def __post_init__(self):
+        # Regarding the C++/OpenMP backend, `cpu_vec_isa.pick_vec_isa()` to check
+        # vectorization ISA is a time-consuming and one-shot operation. It leads
+        # to taking a longer time to import `codegen.cpp` package because the
+        # `LoopLevel` of the package is decorated by `@dataclasses.dataclass` while
+        # the decorator will invoke `cpu_vec_isa.pick_vec_isa()` to initialize the
+        # `simd_nelements` of the `LoopLevel`. It might introduce additional compilation
+        # overhead to the Triton backend.
Therefore, we moved the `simd_nelements` to + # `__post_init__` + picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa() + self.simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0 + + def get_kernels(self) -> List[CppKernel]: + """Get all kernel objects under this loop level""" + if self.kernel: + return [self.kernel] + kernels = [] + for loop in self.inner: + kernels += loop.get_kernels() + return kernels + + def get_root(self): + """Get all kernel objects under this loop level""" + root = self + while root.parent: + root = root.parent + return root + + def set_kernel(self, kernel: CppKernel): + """ + Set the kernel under this loop level. No split is allowed under + this loop level. + """ + if not self.inner: + self.kernel = kernel + loop: Optional[LoopLevel] = self + assert loop is not None + return + assert len(self.inner) == 1 + self.inner[0].set_kernel(kernel) + + def get_loops_at(self, depth) -> List["LoopLevel"]: + if depth == 0: + return [self] + else: + loops = [] + for loop in self.inner: + loops += loop.get_loops_at(depth - 1) + return loops + + def split_with_tiling(self, depth, factor): + def clone_inner(): + inner = [] + if self.inner: + for loop in self.inner: + inner.append(loop.clone()) + return inner + + def do_split_with_tiling(): + sympy_factor = sympy.Integer(factor) + + offset = FloorDiv(self.size, sympy_factor) * sympy_factor + main_loop = LoopLevel(self.var, offset) + main_loop.steps = sympy_factor + main_loop.parallel = self.parallel + main_loop.collapsed = False + main_loop.is_reduction = self.is_reduction + main_loop.inner = clone_inner() + if main_loop.inner: + for loop in main_loop.inner: + loop.parent = main_loop + + tail_loop = LoopLevel(self.var, self.size) + tail_loop.offset = offset + tail_loop.parallel = self.parallel + tail_loop.collapsed = False + tail_loop.is_reduction = self.is_reduction + tail_loop.inner = clone_inner() + if tail_loop.inner: + for loop in tail_loop.inner: + loop.parent = tail_loop + + return main_loop, tail_loop + + if depth == 0: + main_loop, tail_loop = do_split_with_tiling() + parent = self.parent + if parent: + parent.inner = [main_loop, tail_loop] + main_loop.parent = parent + tail_loop.parent = parent + return main_loop, tail_loop + else: + assert len(self.inner) == 1 + return self.inner[0].split_with_tiling(depth - 1, factor) + + def clone(self): + loop = copy(self) + loop.inner = [] + if self.inner: + for inner_loop in self.inner: + inner_loop_clone = inner_loop.clone() + inner_loop_clone.parent = loop + loop.inner.append(inner_loop_clone) + loop.kernel = deepcopy(self.kernel) + return loop + + def lines(self): + offset_expr = cexpr_index(self.offset) + size_expr = cexpr_index(self.size) + if config.cpp.no_redundant_loops and offset_expr == size_expr: + return None + simd = ( + f"simd simdlen({self.simd_nelements}) " + if self.simd_omp and self.simd_nelements > 1 + else "" + ) + if self.parallel: + # TODO(jansel): look into chunk size and other schedules + line1 = "#pragma omp for" + if self.parallel > 1: + line1 += f" collapse({self.parallel})" + if self.simd_omp: + line1 = line1.replace(" for ", f" for {simd}") + elif self.simd_vec: + line1 = "" + elif self.simd_omp: + line1 = f"#pragma omp {simd}" + elif not self.is_reduction and cpp_builder.is_gcc(): + line1 = "#pragma GCC ivdep" + else: + line1 = "" + offset_str = f"{INDEX_TYPE} {self.var}={offset_expr}" + size_str = f"{self.var}<{size_expr}" + if self.steps.is_number: + steps_str = f"{self.var}+={cexpr_index(self.steps)}" + else: + # If 
the step size is 0, change it to 1 because a step size of 0 + # will cause floating point exception (core dump) during parallelization. + steps_str = ( + f"{self.var}+=({cexpr_index(self.steps)} == 0 ? " + f"1 : {cexpr_index(self.steps)})" + ) + line2 = f"for({offset_str}; {size_str}; {steps_str})" + if self.collapsed or not line1: + return [line2] + return [line1, line2] + + +@dataclasses.dataclass +class LoopNestWithSplit: + """ + A loop-nest like structure but with some loop level split along + the loop range into the main tiling loop and the tail. It is built + with the `build` method as a loop nest and then split with + `split_with_tiling` at some depth. + + A typical case is for vectorization where we typically split at the inner-most + loop level. A more complicated case is 2D tiling where we split at + both inner-most and outer levels. + """ + + root: Optional[List[LoopLevel]] = None + kernel: Optional[CppKernel] = None + + @staticmethod + def build(kernel: CppKernel): + """Build a LoopNest with the given `kernel` as the leaf""" + itervars = kernel.itervars + ranges = kernel.ranges + reduction_depth = kernel.reduction_depth + assert reduction_depth is not None + + root: List[LoopLevel] = [] + levels: List[LoopLevel] = root + loop: Optional[LoopLevel] = None + for loop_idx, (var, size) in enumerate(zip(itervars, ranges)): + loop = LoopLevel(var, size, parent=loop) + if loop_idx >= reduction_depth: + loop.is_reduction = kernel.is_reduction + levels.append(loop) + levels = loop.inner + loop_nest = LoopNestWithSplit(root) + if loop: + loop.kernel = kernel + else: + loop_nest.kernel = kernel + return loop_nest + + def __bool__(self): + return bool(self.root) + + def get_loops_at(self, depth) -> List[LoopLevel]: + """Get all the loop levels at the given `depth` (most outer loop has depth 0)""" + loops: List[LoopLevel] = [] + assert self.root is not None + for loop in self.root: + loops += loop.get_loops_at(depth) + return loops + + @cache_on_self + def max_parallel_depth(self): + """ + Maximal allowed depth for parallelism: + 1) Levels without splitting and + 2) All reduction or non-reduction levels + When the loop is split at the top level, the max depth is 1. + """ + max_depth = 0 + assert self.root is not None + loops = self.root + if len(loops) > 1: + return 1 + is_reduction = loops[0].is_reduction if loops else False + while len(loops) == 1 and loops[0].is_reduction == is_reduction: + max_depth += 1 + loops = loops[0].inner + return max_depth + + def is_reduction_only(self): + """ + Whether all the loops are for reduction. Reduction loops + are always the inner most ones. + """ + return ( + self.root is not None and len(self.root) > 0 and self.root[0].is_reduction + ) + + def mark_parallel(self, par_depth): + assert ( + par_depth <= self.max_parallel_depth() + ), "Parallel depth cannot exceed the maximal allowed parallel depth" + assert self.root is not None + loops = self.root + for loop in loops: + loop.parallel = par_depth + for i in range(1, par_depth): + loops = loops[0].inner + loops[0].collapsed = True + + def split_with_tiling(self, depth, factor): + """ + Split the loop into main and tail loops at given `depth` so that the range + of the main loop has range `floor_div(range, factor) * factor` and + the tail loop handles the remainder. The main loop is tiled + according to the `factor`. 
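+
+        (Editor's note: illustrative numbers, not part of the upstream docstring.) For a
+        loop of range 10 split with factor 4:
+
+            size, factor = 10, 4
+            main_end = size // factor * factor   # 8: main loop covers [0, 8) with step 4
+            tail = (main_end, size)              # tail loop covers the remainder [8, 10)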
+ """ + loops = self.get_loops_at(depth) + assert len(loops) == 1 + split_loops = loops[0].split_with_tiling(0, factor) + if depth == 0: + self.root = split_loops + return split_loops + + def get_kernels(self) -> List[CppKernel]: + """Get all kernel objects under this loop nest""" + if self.kernel: + return [self.kernel] + kernels: List[CppKernel] = [] + assert self.root is not None + for loop in self.root: + kernels += loop.get_kernels() + return kernels diff --git a/.venv/Lib/site-packages/torch/_inductor/codegen/cpp_gemm_template.py b/.venv/Lib/site-packages/torch/_inductor/codegen/cpp_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..43fe23f68469c23bec04d22adebf3757c4ca3192 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/codegen/cpp_gemm_template.py @@ -0,0 +1,1043 @@ +# mypy: allow-untyped-defs +import contextlib +import logging +import math +from functools import lru_cache +from typing import Any, Callable, cast, List, Optional, Set, Union +from unittest.mock import patch + +import torch +import torch.utils + +from ..._dynamo.utils import counters +from .. import config, ir, lowering as L +from ..kernel.mm_common import mm_args +from ..select_algorithm import DataProcessorTemplateWrapper +from ..utils import cache_on_self, has_free_symbols, parallel_num_threads +from ..virtualized import ops, V +from .cpp import get_export_declaration +from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType +from .cpp_template import CppTemplate +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import ( + create_epilogue_with_attr, + DTYPE_TO_CPP, + GemmBlocking, + get_gemm_template_output_and_compute_dtype, +) + + +log = logging.getLogger(__name__) + +GEMM_TEMPLATE = r""" +{{template.header().getvalue()}} + +{{micro_gemm.codegen_define(kernel)}} + +{%- if x_scale is not none %} + {%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %} +{%- else %} + {%- set kernel_args = {"X": X, "W": W, "inp": inp} %} +{%- endif %} + +extern "C" {{export_declaration}} +{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}} +{ + {{kernel.maybe_codegen_profile()}} + constexpr int64_t num_threads = {{num_threads}}; + constexpr int64_t N = {{N}}; + constexpr int64_t K = {{K}}; + constexpr int64_t Mr = {{micro_gemm.register_blocking.block_m}}; + constexpr int64_t Nr = {{micro_gemm.register_blocking.block_n}}; + constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}}; + constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; + constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; + +{%- if is_dynamic_M %} + const int64_t M = {{kernel.size(GemmOut, 0)}}; + const int64_t Mr_blocks = (M + Mr - 1) / Mr; + {%- if num_threads > 1 %} + int64_t Mt_blocks, Nt_blocks, Kt_blocks; + mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks); + {%- else %} + const auto Mt_blocks = Mr_blocks; + const auto Nt_blocks = Nr_blocks; + const auto Kt_blocks = Kr_blocks; + {%- endif %} + int64_t Mc_blocks, Nc_blocks, Kc_blocks; + uint32_t L1_cache_size = {{L1_cache_size}}; + uint32_t L2_cache_size = {{L2_cache_size}}; + mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>( + num_threads, + M, + N, + K, + Mr, + Nr, + Kr, + Mt_blocks, + Nt_blocks, + Kt_blocks, + Mc_blocks, + Nc_blocks, + Kc_blocks, + L1_cache_size, + L2_cache_size + ); + const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / 
Mc_blocks;
+    const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
+    const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
+{%- else %}
+    constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
+    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
+    constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}};
+    constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}};
+    constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}};
+    constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}};
+    constexpr int64_t Nc_blocks = {{template.cache_blocking().block_n}};
+    constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}};
+    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
+    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
+    constexpr int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
+{%- endif %}
+
+    // make sure all partitions are assigned
+    {{kernel.assert_function}}(
+        Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= Mr_blocks * Nr_blocks * Kr_blocks,
+        "Not all partitions are assigned."
+    );
+
+{%- if maybe_k_slicing %}
+    std::unique_ptr<std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[]> local_buf_ptrs;
+    if (num_k_slices > 1) {
+        local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_k_slices]);
+    }
+{%- endif %}
+
+{%- if num_threads > 1 %}
+    #pragma omp parallel num_threads({{num_threads}})
+    {
+        const int tid = omp_get_thread_num();
+        int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end;
+        mm_get_thread_blocks(
+            tid, Mr_blocks, Nr_blocks, Kr_blocks, Mt_blocks, Nt_blocks, Kt_blocks,
+            m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end);
+    {%- if maybe_k_slicing %}
+        const int64_t k_group_id = tid / num_k_slices;
+        const int64_t k_slice_id = tid % num_k_slices;
+    {%- endif %}
+{%- else %}
+    {
+        const int tid = 0;
+        const int64_t m_block_start = 0;
+        const int64_t m_block_end = Mr_blocks;
+        const int64_t n_block_start = 0;
+        const int64_t n_block_end = Nr_blocks;
+        const int64_t k_block_start = 0;
+        const int64_t k_block_end = Kr_blocks;
+{%- endif %}
+        {{ micro_gemm.codegen_init(kernel) }}
+{%- if use_local_acc %}
+    {%- set acc_buf_name = "local_acc_buf" %}
+        {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
+{%- endif %}
+        for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
+            const int64_t m_start = mc * Mr;
+            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
+            const int64_t m_size = m_end - m_start;
+            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
+                const int64_t n_start = nc * Nr;
+                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
+                const int64_t n_size = n_end - n_start;
+                // NB: assume we pad N, nc_block_end won't exceed padded N here.
+ const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); +{%- if use_local_acc %} + {%- set acc = kernel.local_buffers[acc_buf_name] %} + {{ kernel.reinit_buffer_if_null(acc_buf_name) }} +{%- else %} + {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %} +{%- endif %} + for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { + int64_t k_start = kc * Kr; + int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); +{%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} + for (int64_t nci = nc; nci < nc_block_end; nci++) { +{%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %} +{%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %} +{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} + if (kc == k_block_start) { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=False)|indent(28, false) }} + } else { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=True)|indent(28, false) }} + } + } + } +{%- if maybe_k_slicing %} + if (num_k_slices > 1) { + const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; + local_buf_ptrs[mxn_cache_block_id * num_k_slices + k_slice_id].reset({{ kernel.release_buffer(acc_buf_name) }}); + } else +{%- endif %} + { +{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %} +{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %} + {{ kernel.store_output( + tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers + )|indent(20, false) + }} + } + } + } +{%- if maybe_k_slicing %} + if (num_k_slices > 1) { + #pragma omp barrier + for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { + // We slice M-dim and each thread in the k-slicing group works on a slice + const int64_t m_start_unsliced = mc * Mr; + const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); + const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced; + const int64_t m_slice_size = (m_size_unsliced + num_k_slices - 1) / num_k_slices; + const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced); + const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced); + const int64_t m_size = m_end - m_start; + const int64_t m_offset = m_start - m_start_unsliced; + for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { + const int64_t n_start = nc * Nr; + const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); + const int64_t n_size = n_end - n_start; + const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; + auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get(); + for (int64_t other_slice = 1; other_slice < num_k_slices; other_slice++) { + auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_k_slices + other_slice].get(); + for (int64_t m = m_offset; m < m_offset + m_size; m++) { + #pragma omp simd + for (int64_t n = 0; n < n_size; n++) { + {{acc_buf_name}}[m*Nr + n] += other_acc[m*Nr + n]; + } + } + } + {%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %} + {{ kernel.store_output( + tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, 
offsets=("m_start", "n_start"), reindexers=reindexers + )|indent(20, false) + }} + } + } + } +{%- endif %} + {{ micro_gemm.codegen_finalize(kernel) }} + } +} +""" + + +def get_padded_n(n, block_n): + return (n + block_n - 1) // block_n * block_n + + +class CppPackedGemmTemplate(CppTemplate): + def __init__( + self, + input_nodes, + layout: ir.Layout, + num_threads: int, + register_blocking: GemmBlocking, + beta=1, + alpha=1, + has_bias=False, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + ) -> None: + assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8] + super().__init__( + "packed_gemm", + input_nodes, + layout, + num_threads, + epilogue_creator=epilogue_creator, + ) + self.beta = beta + self.alpha = alpha + self.has_bias = has_bias + self.register_blocking = register_blocking + m, n = layout.size + _, k = input_nodes[0].get_size() + self.m, self.n, self.k = m, n, k + self.padded_n = get_padded_n(n, self.register_blocking.block_n) + self.is_dynamic_M = has_free_symbols((m,)) + + @cache_on_self + def thread_blocking(self) -> GemmBlocking: + """ + NOTE [Thread blocking in Cpp GEMM] + We use simple heuristics to decide the thread blocking: + 1. Make sure all threads are occupied as much as possible. + 2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse. + 3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing. + TODO(jgong5): allow tuning various blocking options + """ + + @lru_cache(maxsize=100) + def get_factors(number): + factors = [] + for i in range(int(number**0.5), 0, -1): + if number % i == 0: + factors.append(number // i) + factors.append(i) + return factors + + def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks): + thread_block_k = math.ceil(k_blocks / k_factor) + thread_block_n = math.ceil(n_blocks / n_factor) + thread_block_m = math.ceil(m_blocks / m_factor) + return GemmBlocking(thread_block_m, thread_block_n, thread_block_k) + + assert ( + not self.is_dynamic_M + ), "Unable to determine thread blocking for dynamic M." 
+ register_blocking = self.register_blocking + m_blocks = math.ceil(self.m / register_blocking.block_m) + n_blocks = math.ceil(self.n / register_blocking.block_n) + k_blocks = math.ceil(self.k / register_blocking.block_k) + factors = get_factors(self.num_threads) + assert len(factors) > 0 + + if config.cpp.gemm_thread_factors is not None: + factors = [int(i) for i in config.cpp.gemm_thread_factors.split(",")] + assert len(factors) == 3 + assert math.prod(factors) == self.num_threads + return get_blocking( + factors[0], factors[1], factors[2], m_blocks, n_blocks, k_blocks + ) + + # we favor square-sized thread blocks for good data reuse + def get_better_blocking(blocking, best_blocking): + if best_blocking is None: + best_blocking = blocking + else: + block_m_size = blocking.block_m * register_blocking.block_m + block_n_size = blocking.block_n * register_blocking.block_n + best_block_m_size = best_blocking.block_m * register_blocking.block_m + best_block_n_size = best_blocking.block_n * register_blocking.block_n + if blocking.block_k > best_blocking.block_k: + best_blocking = blocking + elif ( + blocking.block_k == best_blocking.block_k + and block_m_size + block_n_size + < best_block_m_size + best_block_n_size + ): + best_blocking = blocking + return best_blocking + + best_blocking = None + # check if we can have a thread-blocking to occupy all threads without k-slicing + for n_factor in factors: + m_factor = self.num_threads // n_factor + if n_blocks >= n_factor and m_blocks >= m_factor: + blocking = get_blocking( + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + if best_blocking is None: + for k_factor in factors: + if k_blocks >= k_factor and ( + config.cpp.gemm_max_k_slices == 0 + or k_factor <= config.cpp.gemm_max_k_slices + ): + n_factors = get_factors(self.num_threads // k_factor) + for n_factor in n_factors: + m_factor = (self.num_threads // k_factor) // n_factor + if n_blocks >= n_factor and m_blocks >= m_factor: + blocking = get_blocking( + m_factor, + n_factor, + k_factor, + m_blocks, + n_blocks, + k_blocks, + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + if best_blocking is None: + for n_factor in factors: + m_factor = self.num_threads // n_factor + if n_blocks >= n_factor or m_blocks >= m_factor: + blocking = get_blocking( + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks + ) + best_blocking = get_better_blocking(blocking, best_blocking) + + assert best_blocking is not None + return best_blocking + + @cache_on_self + def cache_blocking(self) -> GemmBlocking: + def get_cache_blocking(register_blocking, thread_blocking): + Mr = register_blocking.block_m + Nr = register_blocking.block_n + Kr = register_blocking.block_k + + Mt_blocks = thread_blocking.block_m + Nt_blocks = thread_blocking.block_n + Kt_blocks = thread_blocking.block_k + + if config.cpp.gemm_cache_blocking is not None: + blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")] + assert len(blockings) == 3 + Mc_blocks, Nc_blocks, Kc_blocks = blockings + return ( + min(Mc_blocks, Mt_blocks), + min(Nc_blocks, Nt_blocks), + min(Kc_blocks, Kt_blocks), + ) + + # The ratios below are empirically determined to decide + # the effective sizes of L1 and L2. 
+ # TODO: tune the factor here + L1_limit_factor = 0.8 + L2_limit_factor = 0.5 + + L1_cache_size = ( + torch._C._cpu._L1d_cache_size() + ) # per core cache size in Bytes + assert ( + L1_cache_size > 0 + ), f"Expect L1_cache_size > 0 but got {L1_cache_size}" + L1 = L1_cache_size * L1_limit_factor + + L2_cache_size = ( + torch._C._cpu._L2_cache_size() + ) # per core cache size in Bytes + assert ( + L2_cache_size > 0 + ), f"Expect L2_cache_size > 0 but got {L2_cache_size}" + L2 = L2_cache_size * L2_limit_factor + + def get_num_byte(dtype): + return torch.tensor([], dtype=dtype).element_size() + + num_byte_A = get_num_byte(self.input_nodes[0].get_dtype()) + num_byte_B = get_num_byte(self.input_nodes[1].get_dtype()) + + # NOTE [CPP GEMM Cache Blocking Algorithm] + # Our overall strategy is to + # 1) Make cache blocks of B L1-reside and reused by multiple rows of A, i.e. Mc. + # Here, B is Kc x Nr where Nr is a single register block. We use L1 size to + # decide Kc. We want to make Mc large enough to better reuse B. + # 2) Make cache blocks of A L2-reside, which would limit Mc. We want to reuse A + # along N, where we have two sub-strategies (see notes below) to decide Mc and Nc. + + # Step 1: Decide Kc assuming B block is L1-reside. + size_cache_B = Kr * Kt_blocks * Nr * num_byte_B + Kc_blocks = Kt_blocks + if size_cache_B > L1: + Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B)) + + # Step 2: Decide Mc assuming A block is L2-reside. + min_Mc_ratio = 2 # TODO(jgong5): something to tune? + min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr) + assert min_Mc_blocks >= 1 + Kt_bytes = Kt_blocks * Kr * num_byte_A + if min_Mc_blocks * Mr * Kt_bytes < L2: + # Strategy 1: A (Mc x Kt) resides in L2 and reused by all Nt + # when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks) + # to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside + # in L1. + Mc_blocks = min(Mt_blocks, math.floor(L2 / (Mr * Kt_bytes))) + Nc_blocks = 1 + else: + # Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, we reuse + # A (Mc x Kc) in L2 by B (Kc x Nc). C (Mc x Nc) resides in L2. + Mc_blocks = Mt_blocks + Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks) + Nc_bytes = Nc_blocks * Nr * 4 # assume C or acc is float32/int32 + Kc_bytes = Kc_blocks * Kr * num_byte_A + if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2: + # The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2, + # assuming Mc == Nc for good data reuse. + M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8 + if M_max < Mc_blocks * Mr: + Mc_blocks = math.floor(M_max / Mr) + Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks) + + return Mc_blocks, Nc_blocks, Kc_blocks + + assert ( + not self.is_dynamic_M + ), "Unable to determine cache blocking for dynamic M." 
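# The "Strategy 2" branch above caps Mc via the positive root of
# 4*M^2 + Kc_bytes*M - L2 = 0 (accumulator assumed 4 bytes per element,
# Mc == Nc for data reuse). A standalone numeric check of that closed form,
# using hypothetical Kc footprint and L2 budget values (not taken from the file):

import math


def m_max(kc_bytes: float, l2_bytes: float) -> float:
    # Quadratic formula with a = 4, b = Kc_bytes, c = -L2; only the positive root is meaningful.
    return (math.sqrt(kc_bytes * kc_bytes + 16.0 * l2_bytes) - kc_bytes) / 8.0


kc_bytes = 2048.0                # bytes of one A row's Kc panel (hypothetical)
l2_budget = 0.5 * 1024 * 1024    # effective L2 budget after the limit factor (hypothetical)
m = m_max(kc_bytes, l2_budget)
# Plugging the root back in should land exactly on the budget (up to float error).
assert abs(4.0 * m * m + m * kc_bytes - l2_budget) < 1e-6 * l2_budget
print(int(m))  # largest M (rows of A / C) that fits under the assumed budget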
+ register_blocking = self.register_blocking + thread_blocking = self.thread_blocking() + + return GemmBlocking(*get_cache_blocking(register_blocking, thread_blocking)) + + def log_blockings(self): + log.debug(f"Register blocking: {self.register_blocking}") # noqa: G004 + if self.is_dynamic_M: + # thread and cache blockings are determined at runtime for dynamic shapes + return + log.debug(f"Cache blocking: {self.cache_blocking()}") # noqa: G004 + thread_blocking = self.thread_blocking() + log.debug(f"Thread blocking: {thread_blocking}") # noqa: G004 + + def get_occupancy(): + m_blocks = math.ceil(self.m / self.register_blocking.block_m) + n_blocks = math.ceil(self.n / self.register_blocking.block_n) + k_blocks = math.ceil(self.k / self.register_blocking.block_k) + m = math.ceil(m_blocks / thread_blocking.block_m) + n = math.ceil(n_blocks / thread_blocking.block_n) + k = math.ceil(k_blocks / thread_blocking.block_k) + return (m, n, k) + + log.debug( + f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}" # noqa: G004 + ) + + def maybe_k_slicing(self): + if self.num_threads == 1: + return False + if self.is_dynamic_M: + # TODO(jgong5): perhaps use size hint to decide? + return True + register_blocking = self.register_blocking + k_blocks = math.ceil(self.k / register_blocking.block_k) + thread_blocking = self.thread_blocking() + return k_blocks > thread_blocking.block_k + + @staticmethod + def add_choices( + choices, + layout, + input_nodes, + beta=1, + alpha=1, + has_bias=False, + trans_w=False, + input_indices=None, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + ): + if input_indices is None: + input_indices = list(range(len(input_nodes))) + + def reorder_and_filter(inputs, layout_or_out): + if has_bias: + assert len(input_indices) >= 3 + # Assume the input order is [inp, x, w] and we reorder it to [x, w, inp] + inp_idx = input_indices[0] + x_idx = input_indices[1] + w_idx = input_indices[2] + return [ + inputs[x_idx], + inputs[w_idx], + inputs[inp_idx], + *[inputs[idx] for idx in input_indices[3:]], + ], layout_or_out + else: + assert len(input_indices) >= 2 + return [inputs[idx] for idx in input_indices], layout_or_out + + def maybe_to_dense(inputs, layout_or_out): + new_inputs = list(inputs) + if isinstance(inputs[1], torch.Tensor): + W = inputs[1] + new_inputs[1] = W.to_dense() if W.is_mkldnn else W + return new_inputs, layout_or_out + + def normalize_shapes(inputs, layout_or_out): + if not trans_w: + return inputs, layout_or_out + new_inputs = list(inputs) + X = inputs[0] + W = inputs[1] + B = inputs[2] if has_bias else None + if isinstance(W, ir.IRNode): + if trans_w: + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + W = L.permute(W, [1, 0]) + else: + if trans_w: + assert isinstance(W, torch.Tensor) + W = W.transpose(0, 1) + if B is not None: + if isinstance(B, ir.IRNode): + if not isinstance(B, ir.TensorBox): + B = ir.TensorBox(B) + B = L.expand(B, (X.get_size()[0], B.get_size()[-1])) + else: + assert isinstance(B, torch.Tensor) + B = B.expand(X.shape[0], B.shape[-1]) + new_inputs[1] = W + if B is not None: + new_inputs[2] = B + return new_inputs, layout_or_out + + # TODO(jgong5): decide proper number of threads per problem size + num_threads = parallel_num_threads() + new_inputs, _ = normalize_shapes( + *maybe_to_dense(*reorder_and_filter(input_nodes, layout)) + ) + m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1]) + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + new_inputs[0].get_dtype() + ) + 
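# pack_weight below reshapes W into [padded_n // block_n, K, block_n] panels,
# where padded_n rounds N up to a multiple of the micro-kernel's block_n (the
# get_padded_n helper defined earlier in this file). A tiny worked example with
# hypothetical sizes, just to make the shape arithmetic concrete:

n, block_n, k = 1000, 64, 512                        # hypothetical GEMM sizes
padded_n = (n + block_n - 1) // block_n * block_n    # same formula as get_padded_n -> 1024
packed_shape = (padded_n // block_n, k, block_n)     # (16, 512, 64): one K x block_n panel per N block
print(padded_n, packed_shape)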
micro_gemm = create_micro_gemm( + "micro_gemm", + m, + n, + k, + input_dtype=new_inputs[0].get_dtype(), + input2_dtype=new_inputs[1].get_dtype(), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + alpha=alpha, + num_threads=num_threads, + ) + assert micro_gemm is not None + _, block_n, _ = micro_gemm.register_blocking + padded_n = get_padded_n(n, block_n) + + def pack_weight(inputs, layout_or_out): + W = inputs[1] + new_inputs = list(inputs) + blocked_w: Union[ir.IRNode, torch.Tensor] = W + if isinstance(W, ir.IRNode): + new_size = [padded_n // block_n, k, block_n] + blocked_w = ir.Buffer( + W.get_name(), # Borrow the registered buffer name + ir.FixedLayout( + W.get_device(), + W.get_dtype(), + new_size, + ir.FlexibleLayout.contiguous_strides(new_size), + 0, + ), + ) + else: + blocked_w = ( + torch.nn.functional.pad(W, (0, padded_n - n)) + .reshape(k, padded_n // block_n, block_n) + .transpose(0, 1) + .contiguous() + ) + if micro_gemm.get_b_layout() != LayoutType.NORMAL: + layout_str = ( + "VNNI4" + if micro_gemm.get_b_layout() == LayoutType.VNNI4 + else "VNNI2" + ) + assert micro_gemm.get_b_layout() in [ + LayoutType.VNNI2, + LayoutType.VNNI4, + ], f"We only support {layout_str} for now" + vnni_size = ( + 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2 + ) + assert ( + k % vnni_size == 0 + ), f"k should be divisible by vnni_size for {layout_str} layout" + blocked_w = ( + blocked_w.view( + padded_n // block_n, k // vnni_size, vnni_size, block_n + ) + .transpose(-1, -2) + .contiguous() + .view(padded_n // block_n, k, block_n) + ) + # normalize stride to be "contiguous_strides" per size + # this avoids the problems in L.view during template codegen + new_stride = [1] + for sz in reversed(blocked_w.shape[1:]): + new_stride.insert(0, new_stride[0] * sz) + blocked_w = blocked_w.as_strided(blocked_w.shape, new_stride) + new_inputs[1] = blocked_w + + def _is_int8_gemm(inputs): + return ( + isinstance(inputs[0], ir.IRNode) + and inputs[0].get_dtype() == torch.uint8 + ) or ( + isinstance(inputs[0], torch.Tensor) + and inputs[0].dtype == torch.uint8 + ) + + if _is_int8_gemm(new_inputs): + BCompensate = None + if isinstance(W, ir.IRNode): + BCompensate = V.graph.add_tensor_constant( + V.graph.constants[W.get_name() + "_BMatrixCompens"], + W.get_name() + "_BMatrixCompens", + ) + else: + BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0) # type: ignore[assignment] + new_inputs.append(BCompensate) + return new_inputs, layout_or_out + + def preprocessor(inputs, layout): + return pack_weight( + *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout))) + ) + + def postprocessor(output): + if isinstance(output, ir.TensorBox): + # prepack the weight as input to the template buffer + template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) + assert isinstance(template_buffer, ir.CppTemplateBuffer) + new_input_nodes, _ = reorder_and_filter(input_nodes, layout) + + W_node = new_input_nodes[1] + assert W_node.get_name() in V.graph.constants + W = V.graph.constants[W_node.get_name()] + new_input_nodes[1] = W + new_input_nodes, _ = pack_weight( + *normalize_shapes(*maybe_to_dense(new_input_nodes, layout)) + ) + + # By using the new packed weight for the GEMM template, we can prune the + # old weight if it has no other users. This saves memory but makes the FX graph + # non-retraceable. To support retracing, we can add a repack node to the + # FX graph. 
For example: + # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template + W_tensor_users = 0 + for node in reversed(V.graph.graph.nodes): + # Case may happen when the wgt tensor is used by more than 1 get_attr node + # https://github.com/pytorch/pytorch/issues/134998 + if node.op == "get_attr" and hasattr( + V.graph.module, node.name + ): # wgt might already be deleted + comp_tensor = getattr(V.graph.module, node.name) + if ( + W.is_mkldnn == comp_tensor.is_mkldnn + and W.dtype == comp_tensor.dtype + and W.device == comp_tensor.device + and ( + ( + not W.is_mkldnn + and ( + W.untyped_storage().data_ptr() + == comp_tensor.untyped_storage().data_ptr() + ) + ) + or ( + W.is_mkldnn + and ( + torch.ops.mkldnn.data_ptr(W) + == torch.ops.mkldnn.data_ptr(comp_tensor) + ) + ) + ) + ): + W_tensor_users += 1 + + for node in reversed(V.graph.graph.nodes): + # The wgt tensor has been used by only 1 get_attr node + # The get_attr node has only 1 user fx node + if ( + node.name == W_node.get_name() + and len(node.users) == 1 + and W_tensor_users == 1 + ): + del V.graph.constants[node.name] + delattr(V.graph.module, node.name) + delattr(V.graph.graph.owning_module, node.name) + + W_packed = new_input_nodes[1] + W_packed_constant = V.graph.add_tensor_constant(W_packed) + template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input( + W_packed_constant + ) + return output + + template = DataProcessorTemplateWrapper( + CppPackedGemmTemplate, + preprocessor, + postprocessor, + input_nodes=input_nodes, + layout=layout, + num_threads=num_threads, + register_blocking=micro_gemm.register_blocking, + beta=beta, + alpha=alpha, + has_bias=has_bias, + epilogue_creator=epilogue_creator, + ) + template.maybe_append_choice(choices) + return template + + def render( # type: ignore[override,return] + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + flag_template_buffer_has_other_users: Optional[bool] = None, + epilogue_nodes: Optional[List[ir.IRNode]] = None, + **kwargs, + ) -> str: + assert len(self.input_nodes) >= 2 + + int8_gemm = self.input_nodes[0].get_dtype() == torch.uint8 + x_scale = None + x_zp = None + w_scale = None + w_zp = None + if int8_gemm: + X, W = self.input_nodes[0], self.input_nodes[1] + bias_idx = 2 if self.has_bias else 1 + inp = self.input_nodes[bias_idx] if self.has_bias else None + x_scale = self.input_nodes[bias_idx + 1] + x_zp = self.input_nodes[bias_idx + 2] + w_scale = self.input_nodes[bias_idx + 3] + w_zp = self.input_nodes[bias_idx + 4] + Y = self.output_node + else: + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + inp = self.input_nodes[2] if self.has_bias else None + + template_buffer_has_other_users = None + + if template_buffer_node is not None: + # Use the updated prepacked weight buffer + W = template_buffer_node.inputs[1] + Y = template_buffer_node + + assert flag_template_buffer_has_other_users is not None + template_buffer_has_other_users = flag_template_buffer_has_other_users + + template_buffer = Y + gemm_output_buffer = template_buffer + + epilogues: List[ir.IRNode] = [] + reindexers: List[Optional[Callable[[List[Any]], List[Any]]]] = [] + epilogue_creators: List[Callable[[ir.Buffer], ir.Pointwise]] = [] + fake_buffers: List[ir.Buffer] = [] + Y_aliases: Set[str] = set() + + use_local_acc = ( + self.layout.dtype != torch.float + or template_buffer_has_other_users + or int8_gemm + or self.padded_n != self.n + or self.maybe_k_slicing() + ) + + # TODO(jgong5): for int8 gemm, 
bias-add is handled outside of gemm template, + # but we'd better move it here to align with fp. + if inp is not None and self.beta != 0 and not int8_gemm: + # add an epilogue for bias add + def _bias_add_epilogue(buf): + return create_epilogue_with_attr( + buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype + ) + + epilogue_creators.append(_bias_add_epilogue) + + if self.epilogue_creator is not None: + epilogue_creators.append(self.epilogue_creator) + + # When the GEMM output buffer is localized but it has users other than the epilogue nodes, + # we need to copy the value in the GEMM output local buffer to a global buffer. + def need_copy_from_local_to_global_buffer_epilogue( + use_local_acc, template_buffer_has_other_users, epilogue_creators + ): + # The GEMM output buffer is a global buffer, thus copy is not needed. + if not use_local_acc: + return False + + # The possible value of template_buffer_has_other_users is (None, False, True) + # It is None when generating the gemm template during autotune and it will have value during scheduler codegen. + # extra copy_from_local_to_global_buffer_epilogue is not needed in either of the below two cases: + # 1. template_buffer_has_other_users is None (i.e. when doing the codegen during autotune) + # 2. template_buffer_has_other_users is False, which means it's safe to keep the value in the + # GEMM output buffer in local buffer only (no users outside of the epilogues will use its value). + if not template_buffer_has_other_users: + return False + + # When bias is not None or self.epilogue_creator is not None, + # there will be epilogue_creators after the GEMM. + # The GEMM output buffer is localized while + # the output buffer of the epilogue_creators is a global buffer. + if epilogue_creators: + return False + + return True + + if need_copy_from_local_to_global_buffer_epilogue( + use_local_acc, template_buffer_has_other_users, epilogue_creators + ): + + def copy_from_local_to_global_buffer_epilogue(input_buffer: ir.Buffer): + dtype = self.layout.dtype + input_loader = input_buffer.make_loader() + + def copy_inner(index): + input = input_loader(index) + result = ops.to_dtype(input, dtype) + return result + + return ir.Pointwise( + device=input_buffer.get_device(), + dtype=self.layout.dtype, + inner_fn=copy_inner, + ranges=input_buffer.get_size(), + ) + + epilogue_creators.append(copy_from_local_to_global_buffer_epilogue) + + # NOTE [How CPP GEMM template epilogues are organized] + # gemm_output_buffer + # --> zero or more in-template epilogues (created by `epilogue_creators`) --> + # template_buffer + # --> zero or more out-of-template epilogues (`epilogue_nodes`) --> + # Y + if epilogue_creators: + gemm_output_name = "buf_GemmOut" + gemm_output_buffer = ir.Buffer(gemm_output_name, template_buffer.layout) + current_input_buffer = gemm_output_buffer + for i, creator in enumerate(epilogue_creators): + if i == len(epilogue_creators) - 1: + buffer_name = template_buffer.get_name() + else: + buffer_name = f"buf_GemmOut_epilogue_{i}" + epilogues.append( + ir.ComputedBuffer( + name=buffer_name, + layout=template_buffer.layout, + data=creator(current_input_buffer), + ) + ) + fake_buffers.append(current_input_buffer) + Y_aliases.add(current_input_buffer.get_name()) + reindexers.append(None) + if i < len(epilogue_creators) - 1: + current_input_buffer = ir.Buffer( + buffer_name, template_buffer.layout + ) + + Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y + + if epilogue_nodes: + epilogues.extend(epilogue_nodes) + assert Y.get_numel() == 
epilogues[-1].get_numel() + Y = cast(ir.Buffer, epilogues[-1]) + + if not template_buffer_has_other_users: + Y_aliases.add(template_buffer.get_name()) + + if ( + Y.get_size() == template_buffer.get_size() + and Y.get_stride() == template_buffer.get_stride() + ): + reindexers.extend([None] * len(epilogue_nodes)) + Y_2d = Y + else: + + def get_reindexer(epilogue_node): + # From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example: + # template_buffer: + # size (324, 512), stride (512, 1) + # epilogue_node_ordered (ordered by stride decreasingly, in dense format): + # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) + stride_order = list( + ir.get_stride_order( + V.graph.sizevars.size_hints(epilogue_node.get_stride()) + ) + ) + fill_order = ir.stride_order2fill_order(stride_order) + reversed_fill_order = list(reversed(fill_order)) + size_with_stride_ordered_decreasingly = [ + epilogue_node.get_size()[i] for i in reversed_fill_order + ] + reshape_reindex = ir.View.dynamic_reshape_indexer( + size_with_stride_ordered_decreasingly, + template_buffer.get_size(), + ) + + # From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example: + # epilogue_node_ordered (ordered by stride decreasingly, in dense format): + # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) + # epilogue_node: + # size (1, 18, 18, 512), stride (165888, 1, 9216, 512) + from_stride_ordered_decreasingly_to_epilogue_node_order = [ + (len(stride_order) - 1) - stride_order[i] + for i in range(len(stride_order)) + ] + stride_reindex = ir.same_reorder( + from_stride_ordered_decreasingly_to_epilogue_node_order + ) + + reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex) + return reindexer + + reindexers.extend([get_reindexer(epilogue_node) for epilogue_node in epilogue_nodes]) # type: ignore[list-item] + if isinstance(Y, ir.BaseView): + storage = ir.StorageBox(Y.unwrap_view()) + else: + assert isinstance(Y, ir.Buffer) + storage = ir.StorageBox(Y) + Y_2d = ir.ReinterpretView(storage, template_buffer.get_layout()) + + output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( + X.get_dtype() + ) + micro_gemm = create_micro_gemm( + f"{kernel.kernel_name}_micro_gemm", + self.m, + self.n, + self.k, + input_dtype=X.get_dtype(), + input2_dtype=W.get_dtype(), + output_dtype=output_dtype, + compute_dtype=compute_dtype, + alpha=self.alpha, + num_threads=self.num_threads, + ) + assert micro_gemm is not None + assert self.register_blocking == micro_gemm.register_blocking + self.log_blockings() + if isinstance(micro_gemm, CppMicroGemmAMX): + counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1 + + L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes + assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}" + + L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes + assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}" + + options = dict( + X=X, + W=W, + inp=inp, + Y=Y, + N=self.n, + K=self.k, + PADDED_N=self.padded_n, + GemmOut=gemm_output_buffer, + aliases={alias: Y.get_name() for alias in Y_aliases}, + beta=self.beta, + alpha=self.alpha, + num_threads=self.num_threads, + micro_gemm=micro_gemm, + is_dynamic_M=self.is_dynamic_M, + template=self, + kernel=kernel, + export_declaration=get_export_declaration(), + epilogue_nodes=epilogues, + reindexers=reindexers, + Y_2d=Y_2d, + use_local_acc=use_local_acc, + 
maybe_k_slicing=self.maybe_k_slicing(), + x_scale=x_scale, + x_zp=x_zp, + w_scale=w_scale, + w_zp=w_zp, + acc_buf_dtype=torch.int32 if int8_gemm else torch.float, + DTYPE_TO_CPP=DTYPE_TO_CPP, + L1_cache_size=L1_cache_size, + L2_cache_size=L2_cache_size, + config=config, + ) + with contextlib.ExitStack() as stack: + for buf in fake_buffers: + stack.enter_context( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf)) + ) + return self._template_from_string(GEMM_TEMPLATE).render(**options) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py new file mode 100644 index 0000000000000000000000000000000000000000..5f9369856b56d4baab429cd3580b472d6ec855e7 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py @@ -0,0 +1,173 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) 
+permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, 
Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_half_training = MultiOutputPattern([view_default_5, + 
view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py new file mode 100644 index 0000000000000000000000000000000000000000..519fb24c9d4fdded9069ceefbc0a4690234c9bba --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py @@ -0,0 +1,204 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, 
Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_10_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, 
view_default_3, view_default_4) +_sfdp_pattern_10_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_3, Ignored(), _users=2) +permute_default_5 = 
CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_10_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_10_half_inference = 
CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py new file mode 100644 index 0000000000000000000000000000000000000000..123a44c9e02ac4b457de25569d8b4971e2c9fbfd --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py @@ -0,0 +1,203 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, 
permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_11_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, 
memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_11_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, 
Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_11_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) 
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_11_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py new file mode 100644 index 0000000000000000000000000000000000000000..ff83141ef7b3c739820eec7c26fef8b5495c0868 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py @@ -0,0 +1,219 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) 
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) 
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) 
+permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) 
+view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py new file mode 100644 index 0000000000000000000000000000000000000000..ce3e0367162157dfb666fd050cba6bdd57518080 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py @@ -0,0 +1,129 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +neg_default = CallFunction(aten.neg.default, div_Tensor) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4, _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, fma_default, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, fma_default) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, 
sum_dim_IntList) +_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, div_Tensor, KeywordArg('value'), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) +convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, convert_element_type_default_5, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, convert_element_type_default_5) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_half_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) +convert_element_type_default = 
CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, convert_element_type_default_1, KeywordArg('value'), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py new file mode 100644 index 0000000000000000000000000000000000000000..e764711f57e5c233e909f213a67fc37920af3be4 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py @@ -0,0 +1,209 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) 
+expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_14_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, 
div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_14_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = 
CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_14_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, 
KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_14_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6f8194f3ca8cb1785f3b3bbb0c3f433619e280 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py @@ -0,0 +1,227 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = 
CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_15_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, 
exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = 
CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_15_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) 
+expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py new file mode 100644 index 0000000000000000000000000000000000000000..37bed13ead671cef227902ef9ab9238b4d67985a --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py @@ -0,0 +1,598 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, 
Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, 
KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) 
+fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, 
KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = 
CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) 
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) 
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) 
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, 
memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_mask_fp32_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, 
memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, 
expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_mask_fp32_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) 
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_mask_fp32_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py new file mode 100644 index 0000000000000000000000000000000000000000..e375820d22ee06b43ca358b363343ba3677d1e6a --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py @@ -0,0 +1,243 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 
= CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_17_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, 
view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = 
CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = 
CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_17_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py new file mode 100644 index 0000000000000000000000000000000000000000..aa191a007dc42cd5ab048dd064a2fb8938bab5b2 --- /dev/null +++ 
b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py @@ -0,0 +1,452 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) 
+neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = 
CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, 
Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_bs1_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) 
+expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_bs1_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, 
convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, 
permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_half_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_half_inference = 
MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = 
CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_half_bs1_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = 
CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_half_bs1_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py new file mode 100644 index 0000000000000000000000000000000000000000..383bd11086aee14a04af89729d0cb69644426518 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py @@ -0,0 +1,208 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, 
_users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_19_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_19_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), 
dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) 
+scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_3, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_19_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_19_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py 
b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb2aa26a5f757d4482a0a03a477210505e0e099 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py @@ -0,0 +1,173 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, 
bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_2_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_2_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), 
_users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_2_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, 
convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_2_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py new file mode 100644 index 0000000000000000000000000000000000000000..7906861b2d05a392dcd4ef1260db1ad6e287dcb2 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py @@ -0,0 +1,189 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = 
CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, 
div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_3_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) 
+mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_3_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py 
b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py new file mode 100644 index 0000000000000000000000000000000000000000..0a81a9d92fb49ad692ed1c3fbfa1c39ab70b4b0b --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py @@ -0,0 +1,189 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, 
mul_Tensor_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, 
view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_4, Ignored()) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = 
CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py new file mode 100644 index 0000000000000000000000000000000000000000..a77cd612d1b073ff0465ec5d11beb96c86623c1a --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py @@ -0,0 +1,177 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = 
CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = 
CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, 
sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py new file mode 100644 index 0000000000000000000000000000000000000000..6d52ac0029922b1adecc6c2fa3ede31d9c54133f --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py @@ -0,0 +1,193 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, 
KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, 
Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = 
CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py 
b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py new file mode 100644 index 0000000000000000000000000000000000000000..9c28754601c2336e2af5a7310992f073c5be8232 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py @@ -0,0 +1,220 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = 
CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_7_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) 
+amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, 
Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_7_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = 
CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_7_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py new file mode 100644 index 0000000000000000000000000000000000000000..ad36767fc17afdedd60c399f369baf2df081044d --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py @@ -0,0 +1,204 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, 
Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, 
view_default_3, view_default_4) +_sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, Ignored()) +view_default_8 = CallFunction(aten.view.default, 
div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_8_half_inference = 
CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py new file mode 100644 index 0000000000000000000000000000000000000000..901799252c1778ae8a3bf63432bade9117a2e628 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py @@ -0,0 +1,220 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) 
+bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_9_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, 
Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, 
Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_9_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = 
CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..60052fdcce650c7a16da2b9a6fd4906fa4c61c6a --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py @@ -0,0 +1,52 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +addmm_default = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha')) +mul_Scalar = CallFunction(aten.mul.Scalar, KeywordArg('tangents_1'), KeywordArg('beta')) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, mul_Scalar, Ignored(), True) +view_default = CallFunction(aten.view.default, sum_dim_IntList, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +mm_default = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default) +mul_Scalar_1 = CallFunction(aten.mul.Scalar, mm_default, KeywordArg('alpha')) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +mm_default_1 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1')) +mul_Scalar_2 = CallFunction(aten.mul.Scalar, mm_default_1, KeywordArg('alpha')) +addmm_pattern_training = MultiOutputPattern([addmm_default, + view_default, + mul_Scalar_1, + mul_Scalar_2, + None, + None +]) + + +addmm_pattern_inference = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha'), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..40d4a36d6063a26532da652bfae2d88971aaeeed --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py @@ -0,0 +1,44 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2')) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, permute_default_1, KeywordArg('tangents_1')) +bmm_pattern_training = MultiOutputPattern([bmm_default, + bmm_default_1, + bmm_default_2 +]) + + +bmm_pattern_inference = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..8044f66db88aed96acbc75a95c1b15df77ca8827 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py @@ -0,0 +1,44 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2')) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +mm_default_1 = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +mm_default_2 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1')) +mm_pattern_training = MultiOutputPattern([mm_default, + mm_default_1, + mm_default_2 +]) + + +mm_pattern_inference = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0) diff --git a/.venv/Lib/site-packages/torch/_inductor/fx_passes/split_cat.py b/.venv/Lib/site-packages/torch/_inductor/fx_passes/split_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..d3e10c01113edb5440eb3222dd0b9c0868c7259c --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/fx_passes/split_cat.py @@ -0,0 +1,2514 @@ +# mypy: allow-untyped-defs +import itertools +import logging +import operator +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing_extensions import TypeAlias + +import torch +from torch._dynamo.utils import counters + +from ..pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethodVarArgs, + FailedMatch, + 
get_arg_value, + Ignored, + KeywordArg, + ListOf, + Match, + MatchContext, + MULTIPLE, + PatternExpr, + PatternMatcherPass, + register_graph_pattern, + RepeatedExpr, +) +from .group_batch_fusion import is_node_meta_valid, POST_GRAD_FUSIONS, PRE_GRAD_FUSIONS + + +log = logging.getLogger(__name__) + +_Arguments: TypeAlias = Tuple[torch.fx.node.Argument, ...] +_TransformParam: TypeAlias = Tuple[ + Optional[_Arguments], + Optional[_Arguments], + Optional[_Arguments], + Optional[_Arguments], +] +_Range: TypeAlias = Tuple[int, int] + + +PRE_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = {} +POST_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = {} + +pre_grad_pass_names = [ + "normalization_pass", + "remove_split_with_size_one_pass", + "merge_getitem_cat_pass", + "merge_stack_tahn_unbind_pass", + "merge_splits_pass", + "mutate_cat_pass", + "split_cat_pass", + "unbind_stack_pass", + "split_cat_to_slices_pass", + "unbind_cat_to_view_pass", + "split_stack_to_cats_pass", + "unbind_stack_to_slices_pass", + "move_reshape_out_of_split_stack_pass", +] + +post_grad_pass_names = [ + "normalization_aten_pass", + "decompose_mm_pass", + "unbind_stack_aten_pass", + "shape_padding_multiplier", +] + +for pass_name in pre_grad_pass_names: + # exclude all passes from the group batch fusion + # they do not use pattern matcher + if pass_name in PRE_GRAD_FUSIONS: + continue + PRE_GRAD_PATTERNS[pass_name] = PatternMatcherPass( + pass_name=pass_name, + ) + +for pass_name in post_grad_pass_names: + # exclude all passes from the group batch fusion + # they do not use pattern matcher + if pass_name in POST_GRAD_FUSIONS: + continue + POST_GRAD_PATTERNS[pass_name] = PatternMatcherPass( + pass_name=pass_name, + ) + + +def construct_pattern_matcher_pass(pass_name: str): + """ + Return the specific pattern_matcher_pass given the pass name. + """ + if pass_name in PRE_GRAD_PATTERNS: + return PRE_GRAD_PATTERNS[pass_name] + else: + return POST_GRAD_PATTERNS[pass_name] + + +def _get_split_args_default(split_node): + input_kwarg = "tensor" + split_size_kwarg = "split_size_or_sections" + dim_kwarg = "dim" + default_dim_value = 0 + if split_node.op == "call_method": + split_size_kwarg = "split_size" + return ( + get_arg_value(split_node, 0, input_kwarg), + get_arg_value(split_node, 1, split_size_kwarg), + get_arg_value(split_node, 2, dim_kwarg) or default_dim_value, + ) + + +def _get_dim(node: Any): + assert isinstance(node, torch.fx.Node) + if "dim" in node.kwargs: + assert isinstance(node.kwargs["dim"], int) + return node.kwargs["dim"] + if node.target == torch.unbind: + if len(node.args) == 2: + assert isinstance(node.args[-1], int) + return node.args[-1] + return 0 # defaults to dim=0 + if node.target == torch.split: + if len(node.args) == 3: + assert isinstance(node.args[-1], int) + return node.args[-1] + return 0 # defaults to dim=0 + raise AssertionError( + f"Can't extract `dim` from {node.target} {node.args} {node.kwargs}" + ) + + +# noqa: W605 +# ############The pattern to be optimized is######### +# unbind (dim=0) +# / ... \ +# getitem getitem -> user=1 +# | | +# split split -> dim=1, user=1, split_section_size=1 +# | | +# getitem getitem -> user=1 +# \ / +# cat (dim=1) -> user=1 +# | + +# ################After transformation############# +# unbind (dim=0) +# / ... 
\ +# getitem getitem -> user=1 +# \ / +# cat (dim=1) -> user=1 +# | + + +def normalize_split_base( + match: Match, + _get_split_args: Callable[ + [torch.fx.Node], Tuple[Optional[torch.fx.Node], Optional[Any], Optional[int]] + ], +): + """ + Normalize split with split_size into split_with_sizes, so that we only deal with one type of split in + subsequent optimizations + """ + split_node = match.nodes[0] + graph = match.graph + split_input, split_size, split_dim = _get_split_args(split_node) + if split_input is None or split_dim is None or split_size is None: + log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("example value absent for node: %s", split_node) + return + assert isinstance(split_node.meta["example_value"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]] + + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. + return + if split_dim < 0: # Normalize split dim + split_dim += split_input.meta["example_value"].dim() + + new_args = (split_input, split_sections) + new_kwargs = {"dim": split_dim} + if ( + split_node.args == new_args + and split_node.kwargs == new_kwargs + and split_node.op == "call_function" + ): + return + + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=new_args, + kwargs=new_kwargs, # type: ignore[arg-type] + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + graph.erase_node(split_node) + counters["inductor"]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.split, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("split", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_split_default(match: Match, *args, **kwargs): + return normalize_split_base(match, _get_split_args_default) + + +@register_graph_pattern( + CallFunctionVarArgs(torch.split, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("split", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"), +) +def remove_split_with_size_one(match: Match, *args, **kwargs): + graph = match.graph + split_node = match.nodes[0] + split_input, split_size, split_dim = _get_split_args_default(split_node) + if split_input is None or split_dim is None or split_size is None: + log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("example value absent for node: %s", split_node) + return + assert isinstance(split_node.meta["example_value"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]] + + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. 
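# [Editor's note; illustration only, not part of the upstream diff] normalize_split_base
# (defined above) rewrites integer split sizes into explicit section lists, so later
# passes only ever see torch.split(tensor, List[int], dim=...). A minimal eager-mode
# sketch of the equivalence, using an assumed example tensor:
#     x = torch.randn(4, 6)
#     before = torch.split(x, 2, dim=1)           # split_size_or_sections given as an int
#     after = torch.split(x, [2, 2, 2], dim=1)    # normalized to explicit sections
#     assert all(torch.equal(a, b) for a, b in zip(before, after))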
+ return + # remove the dummy split whose split sections size is one + if len(split_sections) == 1: + # find the grand children of the split_node + next_users = find_next_users(split_node) + user = next(iter(split_node.users.keys())) + # replace the users of grand child node with the input node + for next_user in next_users: + next_user.replace_input_with(user, split_input) + # erase the split node and its child + graph.erase_node(user) + graph.erase_node(split_node) + counters["inductor"]["remove_split_with_size_one_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.unbind, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("unbind", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_unbind_default(match: Match, *args, **kwargs): + node = match.nodes[0] + graph = match.graph + input = get_arg_value(node, 0, "input") + dim = get_arg_value(node, 1, "dim") + if dim is None: + axis = node.kwargs.get("axis") + if axis is not None: + dim = axis + else: + dim = 0 + if input is None: + log.debug("couldn't find unbind args") + return + if not is_node_meta_valid(input): + log.debug("example value absent for node: %s", input) + return + ndim = input.meta["example_value"].ndim + if dim < 0: # Normalize unbind dim + dim += ndim + with graph.inserting_after(node): + new_node = graph.call_function( + torch.unbind, + args=(input,), + kwargs={"dim": dim}, + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.cat, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_cat_default(match: Match, *args, **kwargs): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + cat_node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(cat_node, 0, "tensors") + cat_dim = get_arg_value(cat_node, 1, "dim") + if cat_dim is None: + cat_axis = cat_node.kwargs.get("axis") + if cat_axis is not None: + cat_dim = cat_axis + else: + cat_dim = 0 + if tensors is None or cat_dim is None: + log.debug("couldn't find cat args") + return + assert isinstance(tensors, (list, tuple)) + for tensor in itertools.chain([cat_node], tensors): + if not is_node_meta_valid(tensor): + log.debug("example value absent for node: %s", tensor) + return + + ndim = cat_node.meta["example_value"].dim() + + def is_empty_tensor(x): + # special case where torch.cat supports cat'ing with an empty tensor + x_shape = x.meta["example_value"].shape + return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0) + + assert all( + ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors + ) + + if cat_dim < 0: # Normalize cat dim + cat_dim += ndim + + new_args = (tensors,) + new_kwargs = {"dim": cat_dim} + if ( + cat_node.args == new_args + and cat_node.kwargs == new_kwargs + and cat_node.op == "call_function" + ): + return + + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=new_args, + kwargs=new_kwargs, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters["inductor"]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.stack, users=MULTIPLE), + 
pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_stack_default(match: Match, *args, **kwargs): + node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(node, 0, "tensors") + dim = get_arg_value(node, 1, "dim") or 0 + if tensors is None or dim is None: + log.debug("couldn't find stack args") + return + assert isinstance(tensors, (list, tuple)) + + # A bug in pytorch, some nodes miss the example_value metadata + for tensor in itertools.chain([node], tensors): + if not is_node_meta_valid(tensor): + log.debug("example value absent for node: %s", tensor) + return + + ndim = node.meta["example_value"].dim() + if dim < 0: # Normalize dim + dim += ndim + + with graph.inserting_after(node): + new_node = graph.call_function( + node.target, # type: ignore[arg-type] + args=(tensors,), + kwargs={"dim": dim}, + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["normalization_pass"] += 1 + + +def find_next_users(split_node: torch.fx.Node) -> List[torch.fx.Node]: + next_users = [] + for getitem_node in split_node.users.keys(): + for getitem_user in getitem_node.users.keys(): + if getitem_user not in next_users: + next_users.append(getitem_user) + return next_users + + +@register_graph_pattern( + CallMethodVarArgs("squeeze", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_squeeze_default(match: Match, *args, **kwargs): + squeeze_node = match.nodes[0] + squeeze_input = get_arg_value(squeeze_node, 0) + + if "dim" in squeeze_node.kwargs: + assert len(squeeze_node.args) == 1 + dim = squeeze_node.kwargs["dim"] + elif len(squeeze_node.args) == 1: + # squeeze(Tensor) + dim = None + elif len(squeeze_node.args) == 2: + # squeeze(Tensor self, int dim) + # squeeze(Tensor self, int[] dim) + dim = squeeze_node.args[1] + else: + # squeeze(Tensor self, int[] dim) (called with varargs) + dim = squeeze_node.args[1:] + + if isinstance(dim, Sequence) and len(dim) == 1: + dim = dim[0] + + with match.graph.inserting_after(squeeze_node): + if dim is None: + new_squeeze_node = match.graph.call_function( + torch.squeeze, args=(squeeze_input,) + ) + else: + new_squeeze_node = match.graph.call_function( + torch.squeeze, args=(squeeze_input,), kwargs={"dim": dim} + ) + squeeze_node.replace_all_uses_with(new_squeeze_node) + new_squeeze_node.meta.update(squeeze_node.meta) + match.graph.erase_node(squeeze_node) + + +@register_graph_pattern( + CallMethodVarArgs("reshape", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_reshape_default(match: Match, *args, **kwargs): + reshape_node = match.nodes[0] + if not is_node_meta_valid(reshape_node): + log.debug("example value absent for node: %s", reshape_node) + return + reshape_input = get_arg_value(reshape_node, 0) + + with match.graph.inserting_after(reshape_node): + new_reshape_node = match.graph.call_function( + torch.reshape, + args=(reshape_input, tuple(reshape_node.meta["example_value"].shape)), + ) + reshape_node.replace_all_uses_with(new_reshape_node) + new_reshape_node.meta.update(reshape_node.meta) + match.graph.erase_node(reshape_node) + + +class TorchSplit(CallFunction): + """ + Matches a call to torch.split if it is in a normalized form. Ensures that all users of + splits are unique getitems. 
+ """ + + def __init__(self, arg, sizes, func=torch.split) -> None: + # using KeywordArg("dim") for `dim` checks they all match + super().__init__(func, arg, sizes, _users=MULTIPLE, dim=KeywordArg("dim")) + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + m = super()._match(node, ctx) + if not m: + return m + split_sections = node.args[1] + if not isinstance(split_sections, (list, tuple)): + return FailedMatch("split not normalized") + # check users are all unique getitems + seen_idxs = set() + for user in node.users: + if not CallFunction(operator.getitem, Arg(), Arg()).match(user): + # This should ideally never happen. Split user should always be a getitem + return FailedMatch(f"user of split not a getitem: {user}") + if not isinstance(user.args[1], int): + return FailedMatch("only integer getitems are handled") + if user.args[1] in seen_idxs: + return FailedMatch(f"duplicate getitem {user.args[1]}") + if user.args[-1] < 0: # type: ignore[operator] + # This shouldn't ideally happen as dynamo normalizes indexes to positive + return FailedMatch("negative index") + seen_idxs.add(user.args[1]) + return m + + +@register_graph_pattern( + TorchSplit( + CallFunction( + operator.getitem, + TorchSplit( + KeywordArg("first_split_input"), + KeywordArg("first_split_sections"), + ), + Ignored(), + ), + KeywordArg("next_split_sections"), + ), + pass_dict=construct_pattern_matcher_pass("merge_splits_pass"), +) +def merge_splits( + match: Match, + first_split_input: torch.fx.Node, + first_split_sections: List[int], + next_split_sections: List[int], + # Note: dim is implicitly passed by TorchSplit, as it internally uses a pattern with dim + dim: int, +): + node = match.output_node() + # it is possible that the split has no users, + # we check the corner case and skip the pattern + if len(node.users.keys()) == 0: + return + graph = match.graph + first_split = node.args[0].args[0] # type: ignore[union-attr] + next_split_index = node.args[0].args[1] # type: ignore[union-attr] + + new_split_sections = list(first_split_sections) + new_split_sections[next_split_index : next_split_index + 1] = next_split_sections # type: ignore[operator, misc] + + first_split_dim = _get_dim(first_split) + + to_remove = [] + + with graph.inserting_before(first_split): # type: ignore[arg-type] + # Add the new split node + new_split = graph.call_function( + torch.split, + args=(first_split_input, new_split_sections), + kwargs={"dim": first_split_dim}, + ) + if is_node_meta_valid(first_split_input): + new_split.meta["example_value"] = torch.split( + first_split_input.meta["example_value"], + new_split_sections, + dim=first_split_dim, + ) + first_split_num_to_user = { + user.args[1]: user for user in first_split.users.keys() # type: ignore[union-attr] + } + + new_split_num = 0 + for split_num in range(len(first_split_sections)): + if split_num not in first_split_num_to_user: + new_split_num += 1 + continue + old_getitem = first_split_num_to_user[split_num] + if split_num != next_split_index: + old_getitem.update_arg(0, new_split) + old_getitem.update_arg(1, new_split_num) + new_split_num += 1 + else: + next_split_num_to_user = { + user.args[1]: user for user in node.users.keys() + } + # It is not necessary all getitems from the split node are used. + # We use the num of users to check the getitems to be merged. 
+ for next_split_num in range(len(node.users.keys())): + with graph.inserting_after(new_split): + new_getitem = graph.call_function( + operator.getitem, args=(new_split, new_split_num) + ) + new_split_num += 1 + next_getitem = next_split_num_to_user[next_split_num] + new_getitem.meta.update(next_getitem.meta) + next_getitem.replace_all_uses_with(new_getitem) + to_remove.append(next_getitem) + to_remove.append(node) + to_remove.append(old_getitem) + + to_remove.append(first_split) # type: ignore[arg-type] + for node in to_remove: + graph.erase_node(node) + + counters["inductor"]["merge_splits_pass"] += 1 + + +class SplitCatSimplifier: + """ + Helper class to simplify split-cat pattern. In simple cases, both split and cat node can be removed in a "split->cat" + pattern. However, there are various cases where they can't and we need to simplify split/ add transforms before cat. + Some such cases are: + 1. Final node has additional args (not coming from the initial split) + 2. Shuffling of args between split/cat + 3. Some final nodes are non-(cat/stack) + 4. Split-dim != cat-dim (but equal split) + + Note that any combination of the above cases can happen. + + To deal with 1, 2, & 3 - we iterate over all users of split. And figure out common "ranges" that can be merged. + Then, we simplify the split accordingly. In the best case, split can be entirely removed. + + To deal with 4, we add some transformations (unflatten + movedim) (See `get_transform_params`). + + Finally, depending on final node being cat or stack, unsqueeze/flatten needs to be added. + + """ + + def simplify( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + split_sections: List[int], + ): + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # Gather inputs of the next users. When inputs come from `split_node`, they are instead represented by + # a tuple indicating the split ranges. See `get_user_input_list` for more details + user_inputs_list = self.get_user_input_list(split_node, next_users) + # Simplify the split_sections based on user_inputs_list. In simpler cases, len(simplified_split_ranges) == 1 and + # we can simply replace the split node. Otherwise, we simplify it. + simplified_split_ranges = self.get_simplified_split_ranges( + split_sections, next_users, user_inputs_list + ) + if not simplified_split_ranges: # Simplification not possible + return + transform_params_list = self.get_transform_params( + split_node, next_users, user_inputs_list + ) + if not transform_params_list: + return + + # Start actual replacement + user_inputs_list_new = self.replace_split( + graph, split_node, split_sections, user_inputs_list, simplified_split_ranges + ) + self.replace_cat( + graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type] + ) + self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type] + counters["inductor"]["unbind_stack_pass"] += 1 + + def get_user_input_list( + self, split_node: torch.fx.Node, next_users: List[torch.fx.Node] + ) -> List[List[Union[torch.fx.Node, _Range]]]: + """ + Returns list of inputs to the following user nodes, in order. The outer list represents the user node. The inner + list represents the inputs to that particular node. 
This list can either contain + - a tuple representing the ranges of get_items that should go into the cat (closed interval) + - torch.fx.Node representing "other" inputs (which are not coming from our split) + """ + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]] = [] + for user in next_users: + if user.target in {torch.cat, torch.stack}: + user_inputs_list.append(self.get_merged_user_inputs(split_node, user)) + else: + user_inputs_list.append(self.get_non_cat_node_input(split_node, user)) # type: ignore[arg-type] + return user_inputs_list + + def get_merged_user_inputs( + self, split_node: torch.fx.Node, cat_node: torch.fx.Node + ) -> List[Union[torch.fx.Node, _Range]]: + user_inputs = get_arg_value(cat_node, 0, "tensors") + simplified_user_inputs = [] + split_users = set(split_node.users.keys()) + for user_input in user_inputs: + if user_input not in split_users: + simplified_user_inputs.append(user_input) + else: + # Add which "getitem" cat depends on + simplified_user_inputs.append(user_input.args[1]) + return self.merge_consecutive_inputs(simplified_user_inputs) + + def get_non_cat_node_input( + self, split_node: torch.fx.Node, node: torch.fx.Node + ) -> List[_Range]: + """ + Get input for a non cat node in the same format as `get_merged_user_inputs` + """ + node_input = [] + split_users = set(split_node.users.keys()) + for node_arg in node.all_input_nodes: + if node_arg in split_users: + getitem_num = get_arg_value(node_arg, 1) + node_input.append((getitem_num, getitem_num)) + return node_input + + def merge_consecutive_inputs( + self, inputs: List[Union[torch.fx.Node, int]] + ) -> List[Union[torch.fx.Node, _Range]]: + """ + Merge consecutive inputs going into a user node. + + For e.g. + [arg0, 0, 1, 2, arg1] -> [arg0, (0, 2), arg1] + """ + merged_ranges = [] + cur_range = None + for input_ in inputs: + if isinstance(input_, int): + if not cur_range: + cur_range = [input_, input_] + elif input_ == cur_range[1] + 1: + cur_range[1] += 1 + else: + merged_ranges.append(tuple(cur_range)) + cur_range = [input_, input_] + else: + if cur_range: + merged_ranges.append(tuple(cur_range)) + cur_range = None + merged_ranges.append(input_) # type: ignore[arg-type] + if cur_range: + merged_ranges.append(tuple(cur_range)) + return merged_ranges # type: ignore[return-value] + + def get_simplified_split_ranges( + self, + split_sections, + next_users, + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[_Range]]: + ranges = set() + for user_node, user_inputs in zip(next_users, user_inputs_list): + ranges |= { + user_input + for user_input in user_inputs + if isinstance(user_input, tuple) + } + cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() + split_ranges = sorted( + [(cumulative_sizes[r[0]], cumulative_sizes[r[1] + 1]) for r in ranges] + ) + + if not self.has_non_overlapping_ranges( + split_ranges, + ): # This need not be a strict condition + # However, we keep it now for simplicity. 
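# [Editor's note; illustration only, not part of the upstream diff] the split ranges
# computed above are element offsets obtained from a cumulative sum over split_sections,
# so a merged getitem range maps to one contiguous span of the split dimension. For
# example, with assumed sections:
#     sections = [2, 3, 5]
#     cumulative = [0] + torch.cumsum(torch.tensor(sections), 0).tolist()   # [0, 2, 5, 10]
#     r = (0, 1)                                      # getitems 0 and 1 merged
#     offsets = (cumulative[r[0]], cumulative[r[1] + 1])                    # (0, 5)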
+ return None + split_ranges = self.fill_gaps(split_ranges, 0, cumulative_sizes[-1]) + if len(split_sections) == len(split_ranges): # Simplification not possible + return None + counters["inductor"]["scmerge_split_sections_removed"] = len( + split_sections + ) - len(split_ranges) + return split_ranges + + def has_non_overlapping_ranges(self, ranges: List[_Range]) -> bool: + for range_, next_range in zip(ranges, ranges[1:]): + if range_[1] > next_range[0]: + return False + return True + + def fill_gaps(self, ranges: List[_Range], min_: int, max_: int) -> List[_Range]: + cur = min_ + filled_ranges = [] + for a, b in ranges: + if cur < a: + filled_ranges.append((cur, a)) + filled_ranges.append((a, b)) + cur = b + if filled_ranges[-1][1] < max_: + filled_ranges.append((filled_ranges[-1][1], max_)) + return filled_ranges + + def get_transform_params( + self, + split_node: torch.fx.Node, + next_users: List[torch.fx.Node], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[List[_TransformParam]]]: + """ + Figure out what transforms are needed for each input to each cat node. + + We replace a split node with an unflatten followed by a movedim + """ + split_dim = _get_dim(split_node) + split_sections = split_node.args[1] + transform_params_list: List[List[_TransformParam]] = [] + + for user_node, user_inputs in zip(next_users, user_inputs_list): + if user_node.target not in {torch.cat, torch.stack}: + transform_params_list.append([]) + continue + + cat_dim = get_arg_value(user_node, 1, "dim") + transform_params: List[_TransformParam] = [] + for user_input in user_inputs: + if split_dim == cat_dim and user_node.target == torch.cat: + # No transform needed + transform_params.append((None, None, None, None)) + elif isinstance(user_input, tuple): # Split being simplified + # Verify equal split + subset_split_sections = split_sections[ # type: ignore[index] + user_input[0] : user_input[1] + 1 + ] + # All sections should be equal + if len(set(subset_split_sections)) != 1: + return None + + num_splits = len(subset_split_sections) + unflatten_params = (split_dim, (num_splits, -1)) + movedim_params = ( + (split_dim, cat_dim) if split_dim != cat_dim else None + ) + transform_params.append( + (unflatten_params, movedim_params, None, None) + ) + elif ( + user_node.target == torch.stack or split_dim != cat_dim + ): # We need to unsqueeze inputs not coming through split + transform_params.append((None, None, (cat_dim,), None)) + else: # Non-split inputs + transform_params.append((None, None, None, None)) + transform_params_list.append(transform_params) + return transform_params_list + + def replace_split( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + split_sections: List[int], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + split_ranges: List[_Range], + ) -> List[List[torch.fx.Node]]: + """ + Replace the split node. It can either remove the split node if len(split_ranges) == 1, or simplify it + into a split with lesser sections if len(split_ranges) > 1. + + Returns the new `user_inputs_list`, with tuples replaced with new getitems from the newer split node. 
+ """ + split_input = split_node.args[0] + split_dim = _get_dim(split_node) + if len(split_ranges) == 1: # We can completely eliminate the split node + split_items = [split_input] + else: + with graph.inserting_after(split_node): + new_split = graph.call_function( + torch.split, + args=( + split_input, + [r[1] - r[0] for r in split_ranges], + ), + kwargs={"dim": split_dim}, + ) + if is_node_meta_valid(split_input): # type: ignore[arg-type, union-attr] + new_split.meta["example_value"] = torch.split( + split_input.meta["example_value"], [r[1] - r[0] for r in split_ranges], dim=split_dim # type: ignore[union-attr] + ) + counters["inductor"]["scmerge_split_added"] += 1 + split_items = [] + with graph.inserting_after(new_split): + for i in range(len(split_ranges)): + getitem = graph.call_function(operator.getitem, args=(new_split, i)) + if is_node_meta_valid(new_split): + getitem.meta["example_value"] = new_split.meta["example_value"][ + i + ] + split_items.append(getitem) + # Now assign the right getitem to the right input + cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() + new_user_inputs_list = [] + for user_inputs in user_inputs_list: + new_user_inputs = [] + for user_input in user_inputs: + if isinstance(user_input, tuple): + # Find the correct new getitem (present in split_items) + new_user_inputs.append( + split_items[ + split_ranges.index( + ( + cumulative_sizes[user_input[0]], + cumulative_sizes[user_input[1] + 1], + ) + ) + ] + ) + else: + new_user_inputs.append(user_input) + new_user_inputs_list.append(new_user_inputs) + return new_user_inputs_list # type: ignore[return-value] + + def replace_cat( + self, + graph: torch.fx.GraphModule, + split_node: torch.fx.Node, + next_users: List[torch.fx.Node], + user_inputs_list_new, + transform_params_list: List[List[_TransformParam]], + ): + split_dim = _get_dim(split_node) + split_users = split_node.users.keys() + new_cats = [] + for user_node, user_inputs_new, transform_params in zip( + next_users, user_inputs_list_new, transform_params_list + ): + if user_node.target not in {torch.cat, torch.stack}: + # Change the args and kwargs of non-cat/stack nodes. 
Replace old getitems (belonging to + # the original split node) with the newer getitems + next_cat_input = 0 + for input_node in user_node.all_input_nodes: + if input_node in split_users: + user_node.replace_input_with( + input_node, user_inputs_new[next_cat_input] + ) + next_cat_input += 1 + continue + + # Handle cat/stack user nodes + cat_dim = get_arg_value(user_node, 1, "dim") + user_inputs_new_transformed, user_inputs_new_transformed_meta = [], [] + # For `unsqueeze` transform, we will combine consecutive inputs with the same unsqueeze params, and stack them + to_stack, to_stack_meta = [], [] + stack_dim = None + with graph.inserting_before(user_node): + for user_input_new, transform_param in zip( + user_inputs_new, transform_params + ): + if not is_node_meta_valid(user_input_new): + log.debug("example value absent for node: %s", user_input_new) + return + # Apply transforms + ( + unflatten_params, + movedim_params, + unsqueeze_params, + flatten_params, + ) = transform_param + if unsqueeze_params and ( + stack_dim is None or stack_dim == unsqueeze_params[0] + ): + to_stack.append(user_input_new) + to_stack_meta.append(user_input_new.meta["example_value"]) + stack_dim = unsqueeze_params[0] + continue + elif to_stack: + stacked_input = graph.call_function( + torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} + ) + stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr] + to_stack, to_stack_meta = [], [] + stack_dim = None + user_inputs_new_transformed.append(stacked_input) + user_inputs_new_transformed_meta.append( + stacked_input.meta["example_value"] + ) + if unsqueeze_params: + to_stack.append(user_input_new) + stack_dim = unsqueeze_params[0] + to_stack_meta.append(user_input_new.meta["example_value"]) + continue + + if unflatten_params: + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.unflatten, args=(user_input_new, *unflatten_params) + ) + user_input_new.meta["example_value"] = torch.unflatten(user_input_new_meta, *unflatten_params) # type: ignore[arg-type, possibly-undefined, union-attr] + if movedim_params: + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.movedim, args=(user_input_new, *movedim_params) + ) + user_input_new.meta["example_value"] = torch.movedim(user_input_new_meta, *movedim_params) # type: ignore[arg-type, possibly-undefined, union-attr] + if flatten_params: + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.flatten, args=(user_input_new, *flatten_params) + ) + user_input_new.meta["example_value"] = torch.flatten(user_input_new_meta, *flatten_params) # type: ignore[arg-type, possibly-undefined, union-attr] + user_inputs_new_transformed.append(user_input_new) + user_inputs_new_transformed_meta.append( + user_input_new.meta["example_value"] + ) + if to_stack: + stacked_input = graph.call_function( + torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} + ) + stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr] + user_inputs_new_transformed.append(stacked_input) + user_inputs_new_transformed_meta.append( + stacked_input.meta["example_value"] + ) + + with graph.inserting_after(user_node): + if len(user_inputs_new_transformed) > 1: + new_cat_node = graph.call_function( + torch.cat, + args=(user_inputs_new_transformed,), + kwargs={"dim": cat_dim}, + ) + 
new_cat_node.meta["example_value"] = torch.cat( + user_inputs_new_transformed_meta, dim=cat_dim + ) + counters["inductor"]["scmerge_cat_added"] += 1 + else: + new_cat_node = user_inputs_new_transformed[-1] + new_cat_node.meta[ + "example_value" + ] = user_inputs_new_transformed_meta[-1] + + if ( + user_node.target == torch.cat + and split_dim != cat_dim + and split_node.target == torch.split + ): + with graph.inserting_after(new_cat_node): + new_cat_node_meta = new_cat_node.meta["example_value"] + new_cat_node = graph.call_function( + torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1) + ) + new_cat_node.meta["example_value"] = torch.flatten(new_cat_node_meta, cat_dim, cat_dim + 1) # type: ignore[possibly-undefined, union-attr] + user_node.replace_all_uses_with(new_cat_node) + new_cats.append(new_cat_node) + + def erase_old_nodes( + self, + graph: torch.fx.GraphModule, + split_node: torch.fx.Node, + next_users: List[torch.fx.Node], + ): + to_remove = [split_node] + counters["inductor"]["scmerge_split_removed"] += 1 + to_remove.extend(split_node.users.keys()) + for next_user in next_users: + if next_user.target not in {torch.cat, torch.stack}: + continue + counters["inductor"]["scmerge_cat_removed"] += 1 + to_remove.append(next_user) + for node in reversed(to_remove): + if len(node.users.keys()) == 0: + graph.erase_node(node) + + +class UnbindCatRemover(SplitCatSimplifier): + """ + Helper class to merge Unbind->Cat/Stack. Many of the cases are similar to SplitCatSimplifier. + + Unbind can't be simplified like splits. So, we can only remove the unbind node. Other than this, + other cases like multiple users, additional args, dim mismatch are similar to `SplitCatSimplifier`, + hence we extend that class. + """ + + def remove_unbind( + self, + graph: torch.fx.Graph, + unbind_node: torch.fx.Node, + ): + if not is_node_meta_valid(unbind_node): + return + # we need to check if the getitem indices from unbind are consecutive and all go to the same cat node + # before we do the unbind remove, otherwise it will hit the error when we unbind part of them + getitem_indices = [] + for getitem_node in unbind_node.users.keys(): + getitem_indices.append(getitem_node.args[1]) + if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type] + getitem_indices + ) != len( + unbind_node.meta["example_value"] + ): + return + num_unbind = len(getitem_indices) + split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type] + + super().simplify(graph, unbind_node, split_sections) + + def get_simplified_split_ranges( + self, + split_sections: List[int], + next_users: List[torch.fx.Node], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[_Range]]: + simplified_split_ranges = super().get_simplified_split_ranges( + split_sections, next_users, user_inputs_list + ) + if not simplified_split_ranges or len(simplified_split_ranges) != 1: + return None + return simplified_split_ranges + + def get_transform_params( + self, + split_node: torch.fx.Node, + next_users: List[torch.fx.Node], + user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], + ) -> Optional[List[List[_TransformParam]]]: + """ + Figure out what transforms are needed for each input to each cat node. 
+ + Here is the rough transforms we apply: + + x -> unbind -> stack => x -> movedim + + x -> unbind -> cat => x -> movedim -> flatten + + When cat/stack nodes have additional args: + + addn ---| addn -> unsqueeze ---| + x -> unbind -> stack => x -> movedim -> cat + + addn ---| addn ---| + x -> unbind -> cat => x -> movedim -> flatten -> cat + + (Note application of these depends on the dims as well) + + + """ + split_dim = _get_dim(split_node) + transform_params_list: List[List[_TransformParam]] = [] + for user_node, user_inputs in zip(next_users, user_inputs_list): + cat_dim = get_arg_value(user_node, 1, "dim") or 0 + transform_params: List[_TransformParam] = [] + for user_input in user_inputs: + if isinstance(user_input, tuple): + # User input is coming from unbind + movedim_params = ( + (split_dim, cat_dim) if split_dim != cat_dim else None + ) + flatten_params = None + if user_node.target == torch.cat: + flatten_params = (cat_dim, cat_dim + 1) + transform_params.append( + (None, movedim_params, None, flatten_params) + ) + elif ( + user_node.target == torch.stack + ): # We need to unsqueeze inputs not coming through unbind into cat + transform_params.append((None, None, (cat_dim,), None)) + else: # Non-unbind inputs + transform_params.append((None, None, None, None)) + transform_params_list.append(transform_params) + return transform_params_list + + +class GetItem(CallFunction): + def __init__(self, arg, index, _users=1) -> None: + super().__init__(operator.getitem, arg, index, _users=_users) + + def find_anchor_nodes(self, ctx: MatchContext, searched: Set[torch.fx.Node]): + # We generally match GetItem with arg being an Arg(). So, we never return the anchor + # nodes as the stored node in ctx.pattern_to_node is returned. Here we override find_anchor_nodes + # to not use ctx.pattern_to_node + for pattern in self.flat_args_kwargs[0]: + if isinstance(pattern, PatternExpr): + for other_node in pattern.find_anchor_nodes(ctx, searched): + if not isinstance(other_node, torch.fx.Node): + continue + for node in other_node.users: + if node not in searched: + if self._match_fns(node): + yield node + searched.add(node) + + +@register_graph_pattern( + RepeatedExpr( + CallFunction( + torch.squeeze, + GetItem( + TorchSplit( + KeywordArg("split_input"), + KeywordArg("split_sizes"), + ), + Ignored(), + ), + KeywordArg("dim"), + _users=MULTIPLE, + ), + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + RepeatedExpr( + CallFunction( + torch.squeeze, + GetItem( + TorchSplit( + KeywordArg("split_input"), + KeywordArg("split_sizes"), + ), + Ignored(), + ), + dim=KeywordArg("dim"), + _users=MULTIPLE, + ) + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +def merge_split_squeeze( + match: Match, split_input: torch.fx.Node, split_sizes: List[int], dim: int +): + graph = match.graph + split = next(node for node in match.nodes if node.target == torch.split) + if not all(s == 1 for s in split_sizes): + return + if isinstance(dim, Sequence): + return + next_users = find_next_users(split) + if not all(node.target == torch.squeeze for node in next_users): + return + with graph.inserting_before(match.output_node()): + unbind = graph.call_function( + torch.unbind, args=(split_input,), kwargs={"dim": dim} + ) + if is_node_meta_valid(split_input): + unbind.meta["example_value"] = torch.unbind( + split_input.meta["example_value"], dim=dim + ) + for item_index, getitem_node in sorted( + [ + (getitem_node.args[1], getitem_node) + for getitem_node in 
split.users.keys() + ] + ): + squeeze = next(iter(getitem_node.users.keys())) + new_get_item = graph.call_function( + operator.getitem, args=(unbind, item_index) + ) + squeeze.replace_all_uses_with(new_get_item) + new_get_item.meta.update(squeeze.meta) + graph.erase_node(squeeze) + graph.erase_node(getitem_node) + graph.erase_node(split) + counters["inductor"]["split_cat_pass"] += 1 + + +getitem_unbind = ListOf( + GetItem( + CallFunction( + torch.unbind, + KeywordArg("unbind_input"), + dim=KeywordArg("dim"), + _users=MULTIPLE, + ), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + UnbindCatRemover().remove_unbind(match.graph, unbind_node) + + +getitem_split = ListOf( + CallFunction( + operator.getitem, + TorchSplit( + Ignored(), + KeywordArg("split_sections"), + ), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +reshape_getitem_split = ListOf( + CallFunction( + torch.reshape, + CallFunction( + operator.getitem, + TorchSplit( + Ignored(), + KeywordArg("split_sections"), + ), + Ignored(), + _users=MULTIPLE, + ), + Arg(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + tensors=getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + getitem_split, + Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +def simplify_split_cat(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_node = next(node for node in match.nodes if node.target == torch.split) + SplitCatSimplifier().simplify(match.graph, split_node, split_sections) + + +# noqa: W605 +# ############pattern to be optimized is######### + +# split_node(dim=1) +# / \ ... / \ +# getitem getitem getitem getitem -> user=1 +# \ / \ / +# cat (user=mul, dim=1) cat(user=mul, dim=1) +# | \ | \ + +# ################after transformation############# + +# split_node(dim=1) +# / ... 
\ +# getitem getitem +# | \ | \ + + +def has_same_parent_node(node: torch.fx.Node): + # the input nodes of the node should come from the same parent + prev_node = None + for getitem in node.args[0]: # type: ignore[union-attr] + if getitem.target != operator.getitem: # type: ignore[union-attr] + return False + if prev_node is None: + prev_node = getitem.args[0] # type: ignore[union-attr] + else: + if getitem.args[0] != prev_node: + return False + return True + + +def remove_zeros(split_sections: List[int]): + """ + Remove zeros from the list and get the index mapping dict from getitem + in split node to getitem in new split node + """ + new_split_sections, index_mapping = [], {} + idx = 0 + for i in range(len(split_sections)): + if split_sections[i] > 0: + new_split_sections.append(split_sections[i]) + index_mapping[i] = idx + idx += 1 + + return new_split_sections, index_mapping + + +def is_sorted_and_consecutive(arr: List[int]) -> bool: + # check if the array is sorted + if arr == sorted(arr): + # check if the differences between adjacent elements are all 1 + return all(x[1] - x[0] == 1 for x in zip(arr, arr[1:])) + else: + return False + + +def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: List[int]) -> int: + """ + Calculate the fused tensor size in the indices + """ + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] + return fused_tensor_size + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("merge_getitem_cat_pass"), +) +def merge_getitem_cat(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target == torch.split) + split_input, split_size, split_dim = _get_split_args_default(split_node) + # if the cat and split have different dims, return + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # 'immutable_list' object does not support mutation. 
Create a new copy of it + split_sections = list(split_sections) + for cat_user in next_users: + if cat_user.target == torch.cat: + cat_dim = get_arg_value(cat_user, 1, "dim") + # check the all getitems in the cat_user from the same node + # check the input of the cat has all getitem from the split + # check all getitem only has one single user + if ( + split_dim != cat_dim + or not has_same_parent_node(cat_user) + or not all(len(arg.users) == 1 for arg in cat_user.args[0]) # type: ignore[union-attr] + ): + continue + # find the index of getitems to be cated/stacked + indices = [] + for arg in cat_user.args[0]: # type: ignore[union-attr] + indices.append(arg.args[1]) # type: ignore[union-attr] + # the gettitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): + continue + # update the arg of cat user, only keep the first getitem + cat_user.update_arg(0, cat_user.args[0][0]) # type: ignore[index] + # calculate the fused tensor sizes in the indices + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] + # update the split sections + split_sections[indices[0]] = calculate_fused_tensor_size( + split_node, indices + ) + # padding others with zeros to keep the same dict size + for i in indices[1:]: + split_sections[i] = 0 + # remove all unused indexes in the split_node + new_split_sections, index_mapping = remove_zeros(split_sections) + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=(split_input, split_sections), + kwargs={"dim": split_dim}, + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + # remove all unused getitem nodes + to_remove = [cat_user] + # dictionary keys changed during iteration + new_split_getitem_nodes = list(new_split_node.users.keys()) + for getitem_node in new_split_getitem_nodes: + if getitem_node.args[1] in indices[1:]: + to_remove.append(getitem_node) + # update meta data of getitem + elif getitem_node.args[1] == indices[0]: + cat_user.replace_all_uses_with(getitem_node) + getitem_node.meta.update(cat_user.meta) + else: + # update getitem index for new split node + getitem_node.update_arg(1, index_mapping[getitem_node.args[1]]) + graph.erase_node(split_node) + for getitem_node in to_remove: + graph.erase_node(getitem_node) + # update the split sections of new split node + new_split_node.update_arg(1, new_split_sections) + split_node = new_split_node + split_sections = new_split_sections + + counters["inductor"]["merge_getitem_cat_pass"] += 1 + + +# ############pattern to be optimized is######### + +# split_node(dim=1) -> user=multiple +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ \ / \ +# other_op /cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# split_node(dim=1) -> -> user=multiple +# / \ ... 
/ \ +# getitem getitem getitem getitem -> user=multiple +# \ \ / \ +# other_op + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("mutate_cat_pass"), +) +def mutate_cat_node(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target == torch.split) + split_input, split_size, split_dim = _get_split_args_default(split_node) + # if the cat and split have different dims, return + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + for cat_user in next_users: + if cat_user.target == torch.cat: + cat_dim = get_arg_value(cat_user, 1, "dim") or 0 + # check that all getitems in the cat_user from the same node + # check the input of the cat has all getitem from the split + if split_dim != cat_dim or not has_same_parent_node(cat_user): + continue + # find the index of getitems to be cat + indices, idx_to_getitem = [], {} + for getitem in cat_user.args[0]: # type: ignore[union-attr] + indices.append(getitem.args[1]) # type: ignore[union-attr] + idx_to_getitem[getitem.args[1]] = getitem # type: ignore[union-attr] + # the gettitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): + continue + # case 1: the cat uses all getitems from the split + if len(split_sections) == len(cat_user.args[0]): # type: ignore[arg-type] + # replace the users of the cat node to be the input of the split node + cat_user.replace_all_uses_with(split_node.args[0]) # type: ignore[arg-type] + # remove the cat node + graph.erase_node(cat_user) + counters["inductor"]["mutate_cat_pass"] += 1 + # case 2: the cat uses some getitems from the split + elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type] + # check the split dim, and construct the slice tuple + start_fused_size = calculate_fused_tensor_size( + split_node, list(range(indices[0])) + ) + end_fused_size = start_fused_size + calculate_fused_tensor_size( + split_node, indices + ) + slice_list = [] + for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr] + if i != split_dim: + slice_list.append(slice(None, None, None)) + else: + slice_list.append(slice(start_fused_size, end_fused_size, None)) + with graph.inserting_after(split_node): + slice_node = graph.call_function( + operator.getitem, + args=(split_node.args[0], tuple(slice_list)), + ) + cat_user.replace_all_uses_with(slice_node) + slice_node.meta.update(cat_user.meta) + + # remove the cat node + graph.erase_node(cat_user) + counters["inductor"]["mutate_cat_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.cat.default, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"), +) +def normalize_cat_default_aten(match: Match, *args, **kwargs): + cat_node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(cat_node, 0, "tensors") + cat_dim = get_arg_value(cat_node, 1, "dim") + if cat_dim is None: + cat_axis = cat_node.kwargs.get("axis") + if cat_axis is not None: + cat_dim = cat_axis + else: + cat_dim = 0 + if tensors is None or cat_dim is None: + log.debug("couldn't find cat args") + return + assert isinstance(tensors, (list, tuple)) + for tensor in itertools.chain([cat_node], tensors): + if "val" not in 
tensor.meta: + log.debug("val absent for node: %s", tensor) + return + + ndim = cat_node.meta["val"].dim() + + def is_empty_tensor(x: torch.fx.Node) -> bool: + # special case where torch.ops.aten.cat.default supports cat'ing with an empty tensor + x_shape = x.meta["val"].shape + return len(x_shape) == 1 and x_shape[0] == 0 + + assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in tensors) + + if cat_dim < 0: # Normalize cat dim + cat_dim += ndim + + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.ops.aten.cat.default, + args=(tensors,), + kwargs={"dim": cat_dim}, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters["inductor"]["normalization_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat, + ListOf(CallFunctionVarArgs(torch.ops.aten.unsqueeze)), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_aten_pass"), +) +def merge_unbind_stack_aten(match: Match, *args, **kwargs): + node = match.nodes[-1] + graph = match.graph + # pyre-fixme[6] + unsqueeze_nodes = list(node.args[0]) # type: ignore[arg-type] + cat_dim = get_arg_value(node, 1, "dim") + # check the unsqueeze nodes come from the select nodes + if not all( + get_arg_value(unsqueeze_node, 0, "input").target == torch.ops.aten.select + for unsqueeze_node in unsqueeze_nodes + ): + return + select_nodes = [ + get_arg_value(unsqueeze_node, 0, "input") for unsqueeze_node in unsqueeze_nodes + ] + parent_of_select_node = get_arg_value(select_nodes[0], 0, "input") + # check the target of select_nodes are the same + if not all( + select_node.target == torch.ops.aten.select for select_node in select_nodes + ): + return + # check the select nodes come from the same parent node + if not all( + get_arg_value(select_node, 0, "input") == parent_of_select_node + for select_node in select_nodes + ): + return + if len(unsqueeze_nodes) != len(select_nodes): + return + # check the select nodes have the same dim + if not all( + get_arg_value(select_node, 1, "dim") == cat_dim for select_node in select_nodes + ): + return + # check the select nodes have consecutive indices starting from 0 + if get_arg_value(select_nodes[0], 2, "index") != 0 or not is_sorted_and_consecutive( + [get_arg_value(select_node, 2, "index") for select_node in select_nodes] + ): + return + # check the users of parent of select node only from unsqueeze nodes that go to the cat node + # we simply check the number of users of the parent of select node + if len(parent_of_select_node.users.keys()) != len(node.args[0]): # type: ignore[arg-type] + return + node.replace_all_uses_with(parent_of_select_node) + graph.erase_node(node) + for unsqueeze_node in unsqueeze_nodes: + graph.erase_node(unsqueeze_node) + for select_node in select_nodes: + if len(select_node.users) == 0: + graph.erase_node(select_node) + counters["inductor"]["unbind_stack_aten_pass"] += 1 + + +def divide_into_consecutive_sublists(indices: List[int]) -> List[List[int]]: + n = len(indices) + if n <= 1: + return [indices] + + # Initialize the list of sublists + sublists = [] + + # Iterate over the indices + i = 0 + while i < n: + # Initialize the current sublist + sublist = [indices[i]] + + # Iterate over the remaining indices + j = i + 1 + while j < n and indices[j] == indices[j - 1] + 1: + # Add the next index to the current sublist + sublist.append(indices[j]) + j += 1 + + # Add the current sublist to the list of sublists + 
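# [Editor's note; illustration only, not part of the upstream diff] this helper groups
# getitem indices into runs of consecutive values; each run can later be replaced by a
# single slice of the split/unbind input. Expected behaviour on an assumed input:
#     divide_into_consecutive_sublists([0, 1, 2, 5, 6, 9])   # -> [[0, 1, 2], [5, 6], [9]]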
sublists.append(sublist) + # Move to the next index + i = j + + return sublists + + +def update_args_from_split_getitem( + graph: torch.fx.Graph, + node: torch.fx.Node, + getitem_indices: List[int], + parents_seen: List[torch.fx.Node], + new_cat_args: List[torch.fx.Node], + new_cat_args_meta: List[torch.fx.Node], + idx_to_getitems: Dict[int, torch.fx.Node], + threshold_to_cat: int = 2, +): + split_input, split_size, split_dim = _get_split_args_default(parents_seen[-1]) + # case 1: the number of getitems is the same as the split size, elimiate the split + if len(split_size) == len(getitem_indices) and is_sorted_and_consecutive( + getitem_indices + ): + # we can merge the getitems from the previous parent + new_cat_args.append(split_input) + new_cat_args_meta.append(split_input.meta["example_value"]) + else: + if len(getitem_indices) > 0: + # case 2: the number of getitems is smaller than the split size but larger than the threshold, and + # the indices of getitems are not all consecutive, we need to divide the indices into multiple groups + geitem_indices_sublist = divide_into_consecutive_sublists(getitem_indices) + for sublist in geitem_indices_sublist: + if len(sublist) >= threshold_to_cat: + # case 2: the number of getitems is smaller than the split size but larger than the threshold + # we need to slice the input of parent + start_fused_size = sum(split_size[: sublist[0]]) + end_fused_size = sum(split_size[: sublist[-1] + 1]) + slice_list = [] + for i in range(len(split_input.meta["example_value"].shape)): # type: ignore[union-attr] + if i != split_dim: + slice_list.append(slice(None, None, None)) + else: + slice_list.append( + slice(start_fused_size, end_fused_size, None) + ) + with graph.inserting_after(node): + slice_node = graph.call_function( + operator.getitem, + args=(split_input, tuple(slice_list)), + ) + slice_node.meta["example_value"] = split_input.meta[ + "example_value" + ][tuple(slice_list)] + new_cat_args.append(slice_node) + new_cat_args_meta.append(slice_node.meta["example_value"]) + else: + # case 3: the number of getitems is smaller than the threshold, no merge is done + # get the getitems based on the indexes + for i in sublist: + new_cat_args.append(idx_to_getitems[i]) + new_cat_args_meta.append( + idx_to_getitems[i].meta["example_value"] + ) + + +def reshape_cat_node( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + unbind_input: torch.fx.Node, + cat_dim: int, + unbind_dim: int, + cat_shape: torch.Size, +) -> torch.fx.Node: + if cat_dim != unbind_dim: + # construct the permute node args, which has the same shape as the slice node + # then it has the same dim as the unbind_input, i.e., shape of cat + 1 + with graph.inserting_after(cat_node): + permute_list = list(range(len(cat_shape) + 1)) + permute_list[unbind_dim], permute_list[cat_dim] = ( + permute_list[cat_dim], + permute_list[unbind_dim], + ) + permute_node = graph.call_function( + torch.permute, + args=(unbind_input, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + unbind_input.meta["example_value"], permute_list + ) # type: ignore[arg-type] + else: + permute_node = unbind_input + with graph.inserting_after(permute_node): + reshape_node = graph.call_function( + torch.reshape, args=(permute_node, tuple(cat_shape)) + ) + reshape_node.meta["example_value"] = torch.reshape( + permute_node.meta["example_value"], tuple(cat_shape) + ) # type: ignore[arg-type] + return reshape_node + + +def update_args_from_unbind_getitem( + graph: torch.fx.Graph, + node: torch.fx.Node, # cat or stack 
node + getitem_indices: List[int], + parents_seen: List[torch.fx.Node], + new_cat_args: List[torch.fx.Node], + new_cat_args_meta: List[torch.fx.Node], + idx_to_getitems: Dict[int, torch.fx.Node], + threshold_to_cat: int = 2, +): + unbind_input = get_arg_value(parents_seen[-1], 0, "input") # split or unbind input + unbind_dim = get_arg_value(parents_seen[-1], 1, "dim") # split or unbind dim + cat_dim = get_arg_value(node, 1, "dim") # cat or stack dim + # case 1: the number of getitems is the same as the split size, elimiate the split + size = list(unbind_input.meta["example_value"].shape)[unbind_dim] + if size == len(getitem_indices): + cat_shape = torch.cat( + [idx_to_getitems[i].meta["example_value"] for i in getitem_indices], + dim=cat_dim, + ).shape + # we can merge the getitems from the previous parent + reshape_node = reshape_cat_node( + graph, node, unbind_input, cat_dim, unbind_dim, cat_shape + ) + new_cat_args.append(reshape_node) + new_cat_args_meta.append(reshape_node.meta["example_value"]) + elif len(getitem_indices) >= threshold_to_cat and is_sorted_and_consecutive( + getitem_indices + ): + # case 2: the number of getitems is smaller than the split size but larger than the threshold + # we need to slice the input of parent + cat_shape = torch.cat( + [idx_to_getitems[i].meta["example_value"] for i in getitem_indices], + dim=cat_dim, + ).shape + slice_list = [] + for i in range(len(cat_shape) + 1): + if i != unbind_dim: + slice_list.append(slice(None, None, None)) # start, end, step + else: + slice_list.append( + slice(getitem_indices[0], getitem_indices[-1] + 1, None) + ) + with graph.inserting_after(node): + slice_node = graph.call_function( + operator.getitem, + args=(unbind_input, tuple(slice_list)), + ) + slice_node.meta["example_value"] = torch.narrow( + unbind_input.meta["example_value"], + unbind_dim, + getitem_indices[0], + getitem_indices[-1] - getitem_indices[0] + 1, + ) + reshape_node = reshape_cat_node( + graph, node, slice_node, cat_dim, unbind_dim, cat_shape + ) + new_cat_args.append(reshape_node) + new_cat_args_meta.append(reshape_node.meta["example_value"]) + else: + # case 3: the number of getitems is smaller than the threshold, no merge is done + # get the getitems based on the indexes + for i in getitem_indices: + new_cat_args.append(idx_to_getitems[i]) + new_cat_args_meta.append(idx_to_getitems[i].meta["example_value"]) + + +def construct_cat_args( + graph: torch.fx.Graph, + cat_or_stack_node: torch.fx.Node, + inputs: List[torch.fx.Node], + split_or_unbind_node: torch.fx.Node, + threshold_to_cat: int = 2, + run_update_func: Callable = update_args_from_split_getitem, # type: ignore[type-arg] +) -> Tuple[List[torch.fx.Node], List[torch.Tensor]]: + new_cat_args, parents_seen, getitem_indices, idx_to_getitems = [], [], [], {} # type: ignore[var-annotated] + new_cat_args_meta = [] # type: ignore[var-annotated] + for input in inputs: + if input.target != operator.getitem: + # update the last arg based on getitem_indices and parents_seens + if len(parents_seen) > 0: + run_update_func( # type: ignore[arg-type, union-attr] + graph, + cat_or_stack_node, + getitem_indices, + parents_seen, + new_cat_args, + new_cat_args_meta, + idx_to_getitems, # type: ignore[arg-type, union-attr] + threshold_to_cat, + ) + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + # reset the indices array + getitem_indices, idx_to_getitems = [], {} + else: + # get the parent node of the getitem input + parent, idx = input.args[0], input.args[1] # type: 
ignore[union-attr] + if parent.target != split_or_unbind_node.target: # type: ignore[union-attr] + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + continue + # cannot use parents_seen to check since the first item could be non getitem node + if len(parents_seen) == 0: + parents_seen.append(parent) + idx_to_getitems[idx] = input + getitem_indices.append(idx) + # case: we only have one getitem input, and it is in the last position + if input == inputs[-1]: + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + continue + # if it is the last input in the tensors, we also check if it can be optimized + if parent != parents_seen[-1] or input == inputs[-1]: + if input == inputs[-1]: + getitem_indices.append(idx) + idx_to_getitems[idx] = input + run_update_func( # type: ignore[arg-type, union-attr] + graph, + cat_or_stack_node, + getitem_indices, + parents_seen, + new_cat_args, + new_cat_args_meta, + idx_to_getitems, # type: ignore[arg-type, union-attr] + threshold_to_cat, + ) + # reset the indices array for the next parent + # remember to add the last element since it is the first + # item in this round of parent + # add the parent to the list of seen parents + parents_seen.append(parent) + getitem_indices, idx_to_getitems = [idx], {idx: input} + else: + getitem_indices.append(idx) + idx_to_getitems[idx] = input + return new_cat_args, new_cat_args_meta + + +def remove_split_unbind_children(graph: torch.fx.Graph, inputs: List[torch.fx.Node]): + nodes = set() + for input in inputs: + if input.target == operator.getitem: + nodes.add(input.args[0]) # type: ignore[union-attr] + if len(input.users.keys()) == 0: + graph.erase_node(input) + # check the split node to remove if it has no users + for node in nodes: + if len(node.users.keys()) == 0: # type: ignore[union-attr] + graph.erase_node(node) # type: ignore[arg-type] + + +# ############pattern to be optimized is######### + +# split_node(dim=1) -> user=multiple +# / \ ... 
/ \ +# other inputs getitem getitem getitem -> user=multiple +# \ / \ +# cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# split_node(dim=1) other inputs -> -> user=multiple +# / \ +# cat (user=mul, dim=1, split_node) + + +@register_graph_pattern( + CallFunctionVarArgs(torch.cat, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("split_cat_to_slices_pass"), +) +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_to_slices_pass"), +) +def split_cat_to_slices(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_nodes = [node for node in match.nodes if node.target == torch.split] + if split_nodes: + split_node = next(node for node in split_nodes) + else: + # Handle the case where there are no nodes with a target of torch.split + return + split_dim = get_arg_value(split_node, 2, "dim") or 0 + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "split_cat_to_slices_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + next_users = find_next_users(split_node) + for cat_node in next_users: + if cat_node.target != torch.cat or not is_node_meta_valid(cat_node): + continue + cat_inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, _ = construct_cat_args( + graph, + cat_node, + cat_inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: if new cat args has length 1, we can remove the cat node + if len(new_cat_args) == 1: + cat_node.replace_all_uses_with(new_cat_args[0]) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + counters["inductor"]["split_cat_to_slices_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(cat_inputs): + new_args = (new_cat_args,) + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=new_args, + # split and cat have the same dim + kwargs={"dim": split_dim}, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + # remove the cat node + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) + counters["inductor"]["split_cat_to_slices_pass"] += 1 + + +# ############pattern to be optimized is######### + +# unbind(dim=0) -> user=multiple +# / \ ... 
/ \ +# getitem getitem getitem getitem -> user=multiple +# \ / \ +# cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# input_of_unbind +# | \ +# slice +# | +# view +# | + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_unbind, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_cat_to_view_pass"), +) +def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + graph = match.graph + # get the cat_node and check its inputs and meta data + next_users = find_next_users(unbind_node) + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "unbind_cat_to_view_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + for cat_node in next_users: + if cat_node.target != torch.cat or not is_node_meta_valid(cat_node): + continue + inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + cat_node, + inputs, + unbind_node, + threshold_to_cat, + update_args_from_unbind_getitem, + ) + # get the view shape + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + cat_node.replace_all_uses_with(new_cat_args[0]) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + counters["inductor"]["unbind_cat_to_view_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # get the view shape + cat_dim = get_arg_value(cat_node, 1, "dim") + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta["example_value"] = torch.cat(new_cat_args_meta, dim=cat_dim) # type: ignore[arg-type] + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + counters["inductor"]["unbind_cat_to_view_pass"] += 1 + + +def reshape_cat_node_to_stack( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + stack_node: torch.fx.Node, + split_or_unbind_dim: int, +) -> None: + # reshape the cat node to the stack node shape + stack_shape = stack_node.meta["example_value"].shape + stack_dim = _get_dim(stack_node) + if stack_dim != split_or_unbind_dim: + # case 1: the stack dim is not the same as the split dim + # we need to reshape the split input before we do the reshape + reshape_list = list(stack_shape) + reshape_list[stack_dim], reshape_list[split_or_unbind_dim] = ( + reshape_list[split_or_unbind_dim], + reshape_list[stack_dim], + ) + reshape_node = graph.call_function( + torch.reshape, + args=(cat_node, tuple(reshape_list)), + ) + reshape_node.meta["example_value"] = torch.reshape( + cat_node.meta["example_value"], tuple(reshape_list) + ) + permute_list = list(range(len(stack_shape))) + permute_list[stack_dim], permute_list[split_or_unbind_dim] = ( + permute_list[split_or_unbind_dim], + permute_list[stack_dim], + ) + permute_node = graph.call_function( + torch.permute, + 
args=(reshape_node, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + reshape_node.meta["example_value"], permute_list + ) + else: + # case 2: the stack dim is the same as the split dim + # we can directly reshape the split input + permute_node = cat_node + reshape_node = graph.call_function( + torch.Tensor.view, + args=(permute_node, *stack_shape), # type: ignore[arg-type] + ) + stack_node.replace_all_uses_with(reshape_node) + reshape_node.meta.update(stack_node.meta) + stack_inputs = stack_node.args[0] # type: ignore[union-attr] + # remove stack node + graph.erase_node(stack_node) + # check the input of stack node, and remove nodes that have no users + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + + +def convert_reshape_cat_arg_to_stack( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + stack_node: torch.fx.Node, + stack_node_shape: torch.Size, + stack_dim: int, + split_dim: int, +) -> torch.fx.Node: + # reshape the cat node to the stack node shape + cat_shape = cat_node.meta["example_value"].shape + if stack_dim != split_dim: + permute_list = list(range(len(cat_shape))) + permute_list[stack_dim], permute_list[split_dim] = ( + permute_list[split_dim], + permute_list[stack_dim], + ) + permute_node = graph.call_function( + torch.permute, + args=(cat_node, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + cat_node.meta["example_value"], permute_list + ) + else: + permute_node = cat_node + reshape_node = graph.call_function( + torch.Tensor.view, + args=(permute_node, tuple(stack_node_shape)), # type: ignore[arg-type] + ) + reshape_node.meta["example_value"] = torch.Tensor.view( + permute_node.meta["example_value"], tuple(stack_node_shape) # type: ignore[arg-type] + ) + return reshape_node + + +# ############pattern to be optimized is######### +# | | +# split split (dim=1) +# / \ / \ +# getitem ... getitem other ops +# \ | / / +# stack(user=mul, dim=1 or 2) -> can be different dim +# | + +# ################after transformation############# + +# / \ ... 
/ \ +# getitem getitem getitem getitem -> user=multiple +# \ / +# cat(user=mul, dim=1) cat_other_opts +# \ / +# cat +# | +# view +# | + + +@register_graph_pattern( + CallFunction( + torch.stack, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_stack_to_cats_pass"), +) +def split_stack_to_cats(match: Match, split_sections: List[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_node = next(node for node in match.nodes if node.target == torch.split) + split_dim = get_arg_value(split_node, 2, "dim") or 0 + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "split_stack_to_cats_pass" + ].get("threshold_to_cat", 10) + # get the stack_node and check its inputs and meta data + next_users = find_next_users(split_node) + for stack_node in next_users: + if stack_node.target != torch.stack or not is_node_meta_valid(stack_node): + continue + inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, split_dim) + counters["inductor"]["split_stack_to_cats_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + with graph.inserting_after(stack_node): + cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": split_dim}, + ) + cat_node.meta["example_value"] = torch.cat( # type: ignore[arg-type] + new_cat_args_meta, dim=split_dim + ) + reshape_cat_node_to_stack(graph, cat_node, stack_node, split_dim) + counters["inductor"]["split_stack_to_cats_pass"] += 1 + + +# ############pattern to be optimized is######### + +# unbind(dim=1) -> user=multiple +# \ ... 
/ \ +# others getitem getitem getitem -> user=multiple +# \ \ / \ +# stack(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# input_of_unbind +# | \ +# slice +# | +# view others +# | / +# stack +# | + + +@register_graph_pattern( + CallFunction( + torch.stack, + getitem_unbind, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_to_slices_pass"), +) +def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + graph = match.graph + # get the cat_node and check its inputs and meta data + next_users = find_next_users(unbind_node) + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "unbind_stack_to_slices_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + for stack_node in next_users: + if stack_node.target != torch.stack or not is_node_meta_valid(stack_node): + continue + inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + unbind_node, + threshold_to_cat, + update_args_from_unbind_getitem, + ) + unbind_dim = get_arg_value(unbind_node, 1, "dim") or 0 + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, unbind_dim) + counters["inductor"]["unbind_stack_to_slices_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # get the view shape + cat_dim = get_arg_value(stack_node, 1, "dim") + with graph.inserting_after(stack_node): + new_cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta["example_value"] = torch.cat( + new_cat_args_meta, dim=cat_dim + ) + reshape_cat_node_to_stack(graph, new_cat_node, stack_node, unbind_dim) + counters["inductor"]["unbind_stack_to_slices_pass"] += 1 + + +# ############pattern to be optimized is######### +# input +# | +# split(dim=1) -> user=multiple +# \ \ +# others getitem getitem +# \ \ / +# reshape reshape reshape other_op +# \ \ / / +# stack(user=mul, dim=0) +# | + +# ################after transformation############# +# input +# | +# permute +# | +# reshape others +# | / +# cat (dim=0) +# | + + +def get_view_shape_list(cat_arg: torch.fx.Node, stack_dim: int) -> List[int]: + # cat_arg must be the split input + view_shape_list = [] + for user in cat_arg.users.keys(): + if user.target == torch.split: + for getitem in user.users.keys(): + if getitem.target == operator.getitem: + reshape_user = [ + user + for user in getitem.users.keys() + if user.target == torch.reshape + ] + if len(reshape_user) > 0: + view_shape_list = list( + reshape_user[0] + .meta["example_value"] + .unsqueeze(stack_dim) + .shape + ) + view_shape_list[stack_dim] = -1 + return view_shape_list + return view_shape_list + + +@register_graph_pattern( + CallFunction( + torch.stack, + reshape_getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("move_reshape_out_of_split_stack_pass"), +) +def move_reshape_out_of_split_stack(match: Match, *args, **kwargs): + split_node = next(node for node in match.nodes if node.target == torch.split) + split_dim = _get_dim(split_node) + split_users = list(split_node.users.keys()) + stack_nodes = [node 
for node in match.nodes if node.target == torch.stack] + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "move_reshape_out_of_split_stack_pass" + ].get("threshold_to_cat", 10) + for stack_node in stack_nodes: + if not is_node_meta_valid(stack_node): + log.debug("example value absent for node: %s", stack_node) + continue + stack_dim = _get_dim(stack_node) + stack_inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + inputs = [] + for stack_input in stack_inputs: + if stack_input.target != torch.reshape: + inputs.append(stack_input) + else: + inputs.append(stack_input.args[0]) # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_node = convert_reshape_cat_arg_to_stack( + graph, + new_cat_args[0], + stack_node, + stack_node.meta["example_value"].shape, + stack_dim, + split_dim, + ) + stack_node.replace_all_uses_with(reshape_node) + # remove stack node + graph.erase_node(stack_node) + # check the input of stack node, and remove nodes that have no users + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + remove_split_unbind_children(graph, split_users) # type: ignore[arg-type] + counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # decompose the cat args into multiple stack nodes, i.e., we stack + # all the nodes exist in the stack inputs and reshape the rest followed by a cat + stack_node_input, stack_node_input_meta, cat_inputs = [], [], [] # type: ignore[var-annotated] + for cat_arg in new_cat_args: + if cat_arg not in stack_inputs: + if len(stack_node_input) > 0: + with graph.inserting_after(stack_node): + decomposed_stack_node = graph.call_function( + torch.stack, + args=(stack_node_input,), + kwargs={"dim": stack_dim}, + ) + decomposed_stack_node.meta["example_value"] = torch.stack( + stack_node_input_meta, dim=stack_dim + ) + cat_inputs.append(decomposed_stack_node) + # cat_arg must be the split input + view_shape_list = get_view_shape_list(cat_arg, stack_dim) + stack_node_shape = torch.reshape(cat_arg.meta["example_value"], tuple(view_shape_list)).shape # type: ignore[union-attr] + cat_inputs.append( + convert_reshape_cat_arg_to_stack( + graph, + cat_arg, + stack_node, + stack_node_shape, + stack_dim, + split_dim, + ) + ) + stack_node_input, stack_node_input_meta = [], [] + else: + stack_node_input.append(cat_arg) + stack_node_input_meta.append(cat_arg.meta["example_value"]) + + if len(stack_node_input) > 0: + with graph.inserting_after(stack_node): + decomposed_stack_node = graph.call_function( + torch.stack, + args=(stack_node_input,), + kwargs={"dim": stack_dim}, + ) + decomposed_stack_node.meta["example_value"] = torch.stack( + stack_node_input_meta, dim=stack_dim + ) + cat_inputs.append(decomposed_stack_node) + + with graph.inserting_after(stack_node): + cat_node = graph.call_function( + torch.cat, + args=(cat_inputs,), + kwargs={"dim": stack_dim}, + ) + stack_node.replace_all_uses_with(cat_node) + cat_node.meta.update(stack_node.meta) + graph.erase_node(stack_node) + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + remove_split_unbind_children(graph, split_users) # type: 
ignore[arg-type] + counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1 diff --git a/.venv/Lib/site-packages/torch/_inductor/kernel/__init__.py b/.venv/Lib/site-packages/torch/_inductor/kernel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ba0fe9754f43c39cd3adc2187ccb68c1890605c8 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/kernel/__init__.py @@ -0,0 +1 @@ +from . import mm, mm_common, mm_plus_mm, unpack_mixed_mm diff --git a/.venv/Lib/site-packages/torch/_inductor/kernel/bmm.py b/.venv/Lib/site-packages/torch/_inductor/kernel/bmm.py new file mode 100644 index 0000000000000000000000000000000000000000..5dee0b1eca32dff486176f12e8809fba748ce0b5 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/kernel/bmm.py @@ -0,0 +1,192 @@ +# mypy: allow-untyped-defs +import logging + +import torch + +from .. import ir, lowering as L +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + TritonTemplate, +) +from ..utils import ( + ceildiv as cdiv, + use_aten_gemm_kernels, + use_cutlass_template, + use_triton_template, +) +from ..virtualized import V +from .mm import _is_static_problem +from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +def bmm_grid(b, m, n, meta): + return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1) + + +bmm_template = TritonTemplate( + name="bmm", + grid=bmm_grid, + source=r""" +{{def_kernel("A", "B")}} + M = {{size("A", -2)}} + N = {{size("B", -1)}} + K = {{size("A", -1)}} + + stride_aq = {{stride("A", 0)}} + stride_am = {{stride("A", 1)}} + stride_ak = {{stride("A", 2)}} + + stride_bq = {{stride("B", 0)}} + stride_bk = {{stride("B", 1)}} + stride_bn = {{stride("B", 2)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + ram = rm % M + if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1): + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + rbn = rn % N + + rk = tl.arange(0, BLOCK_K) + + idx_q = tl.program_id(1) # batch dimension for BMM + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) 
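+        # accumulate this K-tile's partial product into the (BLOCK_M, BLOCK_N) accumulator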
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_q = tl.program_id(1) # batch dimension for BMM + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}} +""", +) + +aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out") +aten_baddbmm = ExternKernelChoice(torch.baddbmm, "at::baddbmm_out") + + +@L.register_lowering(aten.bmm) +def tuned_bmm(mat1, mat2, *, layout=None): + if all(x.get_device().type == "cpu" for x in [mat1, mat2]): + # decompose to small ops when memory bound + if mat1.get_size()[1] == 1 or mat2.get_size()[2] == 1: + mat1 = L.unsqueeze(mat1, -1) + mat2 = L.unsqueeze(mat2, 1) + return L.sum_(L.mul(mat1, mat2), axis=2) + + def is_valid_to_require_contiguous(t): + if not ir.is_storage_and_layout(t): + return True + _, layout = ir.as_storage_and_layout(t, freeze=False) + return isinstance(layout, ir.FlexibleLayout) + + def is_preferred_layout_as_bmm_input(sizes, strides): + # contiguous on one of the last two dims + return ( + strides[-1] == 1 and (sizes[-2] == 1 or strides[-2] >= sizes[-1]) + ) or (strides[-2] == 1 and (sizes[-1] == 1 or strides[-1] >= sizes[-2])) + + # Make the input of bmm contiguous + # if it is not contiguous on either of the last two dims, + # because bmm cpu implementation would do contiguous() if not. + # This is to avoid additional copies in bmm. + def may_require_contiguous(t, meta_t): + sizes = meta_t.meta["val"].size() + strides = meta_t.meta["val"].stride() + if not is_preferred_layout_as_bmm_input(sizes, strides): + t = ir.ExternKernel.require_contiguous(t) + return t + + if is_valid_to_require_contiguous(mat1): + meta_mat1 = V.graph.current_node.args[0] + mat1 = may_require_contiguous(mat1, meta_mat1) + if is_valid_to_require_contiguous(mat2): + meta_mat2 = V.graph.current_node.args[1] + mat2 = may_require_contiguous(mat2, meta_mat2) + + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + + # options to tune from + choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] + if use_triton_template(layout): + for config in mm_configs(m, n, k): + bmm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout) + if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k): + from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate + + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) + + if len(choices) == 0: + log.warning("No choices for GEMM, using ATen backend as fallback") + choices.append(aten_bmm.bind((mat1, mat2), layout)) + + return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout) + + +# Don't register this since it is slower than decomposing it +# @L.register_lowering(aten.baddbmm) +def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): + m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout) + + # options to tune from + choices = ( + [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)] + if use_aten_gemm_kernels() + else [] + ) + if use_triton_template(layout): + for config in mm_configs(m, n, k): + bmm_template.maybe_append_choice( + choices, 
+ input_nodes=(inp, mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + ) + + return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout) diff --git a/.venv/Lib/site-packages/torch/_inductor/kernel/conv.py b/.venv/Lib/site-packages/torch/_inductor/kernel/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..f4714158d7bdfceef01ff09b8320474791bdd649 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/kernel/conv.py @@ -0,0 +1,679 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import logging +from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDict + +import torch + +from .. import config, ir +from ..lowering import ( + add_layout_constraint, + constrain_to_fx_strides, + lowerings as L, + register_lowering, +) +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + TritonTemplate, +) +from ..utils import ( + ceildiv, + is_ones, + is_zeros, + pad_listlike, + sympy_product, + use_triton_template, +) +from ..virtualized import V +from .mm_common import filtered_configs + + +if TYPE_CHECKING: + from ..ir import TensorBox + +log = logging.getLogger(__name__) + + +aten = torch.ops.aten + + +def conv2d_grid(n, c, h, w, meta): + return ( + ceildiv(n * h * w, meta["BLOCK_M"]), + ceildiv(c, meta["BLOCK_N"]), + meta["GROUPS"], + ) + + +def conv3d_grid(n, c, d, h, w, meta): + return ( + ceildiv(n * d * h * w, meta["BLOCK_M"]), + ceildiv(c, meta["BLOCK_N"]), + meta["GROUPS"], + ) + + +# List of dictionaries to store the kernel configs. Configs that evaluate to true +# will be utilised on the target platform +kernel_configs = [ + # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps" + {"config": (64, 256, 16, 2, 4), "cond": True}, + {"config": (256, 64, 16, 2, 4), "cond": True}, + {"config": (1024, 16, 16, 1, 8), "cond": True}, + {"config": (128, 128, 32, 2, 8), "cond": True}, + {"config": (64, 64, 32, 2, 4), "cond": True}, + {"config": (64, 256, 32, 2, 8), "cond": True}, + {"config": (256, 64, 32, 2, 8), "cond": True}, +] + +# Create filtered list of configs based on conv +platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in kernel_configs + if config["cond"] +) + +# On ROCm convert num_stages to 1 as pipelining provides no benefit +if torch.version.hip: + platform_configs = tuple( + (config[0], config[1], config[2], 1, config[4]) for config in platform_configs + ) + +conv_configs = functools.partial( + filtered_configs, + configs=platform_configs, +) + +LOOP_BODY_2D = """ + idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H + idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W + idx_x_c = tl.arange(0, BLOCK_K) + k + + x_ptrs = x_base + ( + (idx_x_h * stride_xh)[:, None] + + (idx_x_w * stride_xw)[:, None] + + (idx_x_c * stride_xc)[None, :] + ) + mask_x = ( + (idx_n < BATCH)[:, None] + & (idx_x_h >= 0)[:, None] + & (idx_x_h < IN_H)[:, None] + & (idx_x_w >= 0)[:, None] + & (idx_x_w < IN_W)[:, None] + & (idx_x_c < GROUP_IN_C)[None, :] + ) + matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) + + w_ptrs = w_base + ( + (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww) + ) + mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C) + matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) + acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32) +""" + +""" +This is 
a relatively simple conv implementation that can likely be +improved. Many alternate conv versions can be found here: +https://github.com/pytorch/torchdynamo/pull/971 +""" +conv2d_template = TritonTemplate( + name="convolution2d", + grid=conv2d_grid, + source=r""" +{{def_kernel("X", "W")}} + # Tensor dimensions + BATCH = {{size("X", 0)}} + IN_C = {{size("X", 1)}} + IN_H = {{size("X", 2)}} + IN_W = {{size("X", 3)}} + OUT_C = {{size(None, 1)}} + OUT_H = {{size(None, 2)}} + OUT_W = {{size(None, 3)}} + + # Strides: + stride_xn = {{stride("X", 0)}} + stride_xc = {{stride("X", 1)}} + stride_xh = {{stride("X", 2)}} + stride_xw = {{stride("X", 3)}} + stride_wc_out = {{stride("W", 0)}} + stride_wc_in = {{stride("W", 1)}} + stride_wh = {{stride("W", 2)}} + stride_ww = {{stride("W", 3)}} + + nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + idx_y_w = nhw % OUT_W + nh = nhw // OUT_W + idx_y_h = nh % OUT_H + idx_n = nh // OUT_H + idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + +{% if GROUPS == 1 %} + group = 0 + GROUP_IN_C = IN_C + GROUP_OUT_C = OUT_C +{% else %} + group = tl.program_id(2) + GROUP_IN_C = IN_C // GROUPS + GROUP_OUT_C = OUT_C // GROUPS +{% endif %} + + x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None] + w_base = ( + W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :] + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + +{% if UNROLL %} +{% for i in range(KERNEL_H) %} +{% for j in range(KERNEL_W) %} + i = {{i}} + j = {{j}} + for k in range(0, GROUP_IN_C, BLOCK_K): + """ + + LOOP_BODY_2D + + """ +{% endfor %} +{% endfor %} +{% else %} + # Could be simplified, but slightly slower: + # for i in range(KERNEL_H): + # for j in range(KERNEL_W): + # for k in range(0, GROUP_IN_C, BLOCK_K): + BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K + for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT): + k = (ijk % BLOCK_K_COUNT) * BLOCK_K + ij = ijk // BLOCK_K_COUNT + i = ij // KERNEL_W + j = ij % KERNEL_W + """ + + LOOP_BODY_2D + + """ +{% endif %} + + mask = ( + (idx_n < BATCH)[:, None] + & (idx_y_h < OUT_H)[:, None] + & (idx_y_w < OUT_W)[:, None] + & (idx_y_c < GROUP_OUT_C)[None, :] + ) + idx_n = idx_n[:, None] + idx_c = idx_y_c[None, :] + group * GROUP_OUT_C + idx_h = idx_y_h[:, None] + idx_w = idx_y_w[:, None] + + # inductor generates a suffix + {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}} +""", +) + +LOOP_BODY_3D = """ + idx_x_d = d - PADDING_D + idx_y_d * STRIDE_D + idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H + idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W + idx_x_c = tl.arange(0, BLOCK_K) + k + + x_ptrs = x_base + ( + (idx_x_d * stride_xd)[:, None] + + (idx_x_h * stride_xh)[:, None] + + (idx_x_w * stride_xw)[:, None] + + (idx_x_c * stride_xc)[None, :] + ) + mask_x = ( + (idx_n < BATCH)[:, None] + & (idx_x_d >= 0)[:, None] + & (idx_x_d < IN_D)[:, None] + & (idx_x_h >= 0)[:, None] + & (idx_x_h < IN_H)[:, None] + & (idx_x_w >= 0)[:, None] + & (idx_x_w < IN_W)[:, None] + & (idx_x_c < GROUP_IN_C)[None, :] + ) + matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) + + w_ptrs = w_base + ( + (idx_x_c * stride_wc_in)[:, None] + + (d * stride_wd) + (i * stride_wh) + (j * stride_ww) + ) + mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C) + matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) + acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32) +""" + +conv3d_template = TritonTemplate( + name="convolution3d", + grid=conv3d_grid, + source=r""" +{{def_kernel("X", "W")}} 
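+    # Same implicit-GEMM structure as the 2D template above, extended with a depth (D) dimension.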
+ # Tensor dimensions + BATCH = {{size("X", 0)}} + IN_C = {{size("X", 1)}} + IN_D = {{size("X", 2)}} + IN_H = {{size("X", 3)}} + IN_W = {{size("X", 4)}} + OUT_C = {{size(None, 1)}} + OUT_D = {{size(None, 2)}} + OUT_H = {{size(None, 3)}} + OUT_W = {{size(None, 4)}} + + # Strides: + stride_xn = {{stride("X", 0)}} + stride_xc = {{stride("X", 1)}} + stride_xd = {{stride("X", 2)}} + stride_xh = {{stride("X", 3)}} + stride_xw = {{stride("X", 4)}} + stride_wc_out = {{stride("W", 0)}} + stride_wc_in = {{stride("W", 1)}} + stride_wd = {{stride("W", 2)}} + stride_wh = {{stride("W", 3)}} + stride_ww = {{stride("W", 4)}} + + ndhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + idx_y_w = ndhw % OUT_W + ndh = ndhw // OUT_W + idx_y_h = ndh % OUT_H + nd = ndh // OUT_H + idx_y_d = nd % OUT_D + idx_n = nd // OUT_D + idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + +{% if GROUPS == 1 %} + group = 0 + GROUP_IN_C = IN_C + GROUP_OUT_C = OUT_C +{% else %} + group = tl.program_id(2) + GROUP_IN_C = IN_C // GROUPS + GROUP_OUT_C = OUT_C // GROUPS +{% endif %} + + x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None] + w_base = ( + W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :] + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + +{% if UNROLL %} +{% for d in range(KERNEL_D) %} +{% for i in range(KERNEL_H) %} +{% for j in range(KERNEL_W) %} + d = {{d}} + i = {{i}} + j = {{j}} + for k in range(0, GROUP_IN_C, BLOCK_K): + """ + + LOOP_BODY_3D + + """ +{% endfor %} +{% endfor %} +{% endfor %} +{% else %} + # Could be simplified, but slightly slower: + # for d in range(KERNEL_D): + # for i in range(KERNEL_H): + # for j in range(KERNEL_W): + # for k in range(0, GROUP_IN_C, BLOCK_K): + BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K + for dijk in range(KERNEL_D * KERNEL_H * KERNEL_W * BLOCK_K_COUNT): + k = (dijk % BLOCK_K_COUNT) * BLOCK_K + dij = dijk // BLOCK_K_COUNT + j = dij % KERNEL_W + di = dij // KERNEL_W + i = di % KERNEL_H + d = di // KERNEL_H + """ + + LOOP_BODY_3D + + """ +{% endif %} + + mask = ( + (idx_n < BATCH)[:, None] + & (idx_y_d < OUT_D)[:, None] + & (idx_y_h < OUT_H)[:, None] + & (idx_y_w < OUT_W)[:, None] + & (idx_y_c < GROUP_OUT_C)[None, :] + ) + idx_n = idx_n[:, None] + idx_c = idx_y_c[None, :] + group * GROUP_OUT_C + idx_d = idx_y_d[:, None] + idx_h = idx_y_h[:, None] + idx_w = idx_y_w[:, None] + + # inductor generates a suffix + {{store_output(("idx_n", "idx_c", "idx_d", "idx_h", "idx_w"), "acc", "mask")}} +""", +) + +aten_convolution = ExternKernelChoice( + torch.convolution, + "at::convolution", + has_out_variant=False, + op_overload=aten.convolution.default, +) + + +def conv1x1_via_mm(x, w, *, out): + w = torch.squeeze(torch.squeeze(w, -1), -1) + return torch.matmul( + x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1) + ) + + +aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None) + + +class ConvLayoutParams(TypedDict): + stride: tuple[int, ...] + padding: tuple[int, ...] + dilation: tuple[int, ...] + transposed: bool + output_padding: tuple[int, ...] 
+ groups: int + + +def conv_layout( + x: TensorBox, + weight: TensorBox, + bias: Optional[TensorBox], + stride: Sequence[int], + padding: tuple[int, ...], + dilation: tuple[int, ...], + transposed: bool, + output_padding: tuple[int, ...], + groups: int, +) -> ir.Layout: + """Determine output layout for a convolution""" + with V.graph.fake_mode: + output = torch.ops.aten.convolution( + ir.ir_node_to_tensor(x, guard_shape=True), + ir.ir_node_to_tensor(weight, guard_shape=True), + ir.ir_node_to_tensor(bias, guard_shape=True), + V.graph.sizevars.size_hints(stride), # type: ignore[arg-type] + V.graph.sizevars.size_hints(padding), # type: ignore[arg-type] + V.graph.sizevars.size_hints(dilation), # type: ignore[arg-type] + transposed, + V.graph.sizevars.size_hints(output_padding), # type: ignore[arg-type] + groups, + ) + sizes = ir.convert_shape_to_inductor(output.size()) + stride = ir.convert_shape_to_inductor(output.stride()) # type: ignore[assignment] + + return ir.FixedLayout( + x.get_device(), + x.get_dtype(), + sizes, + stride, + ) + + +def channels_last_order(rank): + order = list(reversed(range(rank))) + order.insert(1, order.pop(-1)) + return order + + +def convert_1x1_conv_to_mm(x, weight, bias): + # special case for 1x1 convolution, which is actually just a matmul + rank = len(weight.get_size()) + for _ in range(rank - 2): + weight = L[aten.squeeze](weight, dim=-1) + weight = L[aten.permute](weight, [1, 0]) + + x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank)) + x_permute = list(range(rank)) + x_permute.append(x_permute.pop(1)) + x = L[aten.permute](x, x_permute) + *sizes, in_chan = x.get_size() + x = L[aten.reshape](x, [sympy_product(sizes), in_chan]) + if bias is None: + result = L[aten.mm](x, weight) + else: + result = L[aten.addmm](bias, x, weight) + result = L[aten.reshape](result, [*sizes, -1]) + result_permute = list(range(rank)) + result_permute.insert(1, result_permute.pop(-1)) + return L[aten.permute](result, result_permute) + + +@register_lowering(aten.convolution) +def convolution( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, +): + stride = tuple(stride) + padding = tuple(padding) + dilation = tuple(dilation) + output_padding = tuple(output_padding) + if not isinstance(groups, int): + groups = V.graph.sizevars.evaluate_static_shape(groups) + assert isinstance(groups, int) + + # Need use hint for triton template since the template does not + # work with a dynamic shape. + # + # No need to evaluate_static_shape for dilation and output_padding + # since the template is only used when dilation is 1 and output_padding + # is 0. 
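+    # Specializing on these hints means the kernel may be recompiled if stride/padding change across calls.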
+ stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride)) + padding = tuple(V.graph.sizevars.evaluate_static_shapes(padding)) + + kwargs: ConvLayoutParams = { + "stride": stride, + "padding": padding, + "dilation": dilation, + "transposed": transposed, + "output_padding": output_padding, + "groups": groups, + } + + if len(x.get_size()) == len(weight.get_size()) - 1: + # add batch dimension to simplify rest of function + return L[aten.squeeze]( + convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs), + dim=0, + ) + + out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes( + weight.get_size() + ) + ndim = len(kernel_shape) + stride = pad_listlike(stride, ndim) + padding = pad_listlike(padding, ndim) + dilation = pad_listlike(dilation, ndim) + output_padding = pad_listlike(output_padding, ndim) + + def channels_last_conv(): + if V.graph.layout_opt and ndim == 2: + return True + + layout = conv_layout(x, weight, None, **kwargs) + req_stride_order = ir.get_stride_order( + V.graph.sizevars.size_hints(layout.stride) + ) + return req_stride_order == ir.NHWC_STRIDE_ORDER + + autotuning_gemm = config.max_autotune or config.max_autotune_gemm + + if ( + (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv())) + and is_ones(kernel_shape) + and is_ones(stride) + and is_zeros(padding) + and is_ones(dilation) + and not transposed + and is_zeros(output_padding) + and groups == 1 + and V.graph.sizevars.statically_known_gt(sympy_product(x.get_size()), 0) + ): + return convert_1x1_conv_to_mm(x, weight, bias) + + if bias is not None and ir.get_device_type(x) != "cpu": + # peel off the bias, cudnn is slower with it + result = convolution(x, weight, None, **kwargs) + return L[aten.add]( + result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1]) + ) + + x.realize() + weight.realize() + + # ndim can be 1 for convolution in models such as demucs + # TODO: check if it's beneficial to convert Conv1d to Conv2d and then + # apply channels last. + if V.graph.layout_opt and ndim == 2: + V.graph.num_channels_last_conv += 1 + x = ir.ExternKernel.require_channels_last(x) + # TODO maybe we can convert weights to channels last just once before + # running the model. + weight = ir.ExternKernel.require_channels_last(weight) + layout = conv_layout(x, weight, None, **kwargs) + else: + layout = conv_layout(x, weight, None, **kwargs) + req_stride_order = ir.get_stride_order( + V.graph.sizevars.size_hints(layout.stride) + ) + x = ir.ExternKernel.require_stride_order(x, req_stride_order) + weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) + + ordered_kwargs_for_cpp_kernel = [ + "stride", + "padding", + "dilation", + "transposed", + "output_padding", + "groups", + ] + if bias is None: + args = [x, weight] + kwargs["bias"] = None # type: ignore[typeddict-unknown-key] + ordered_kwargs_for_cpp_kernel.insert(0, "bias") + else: + args = [x, weight, bias] + bias.realize() + bias.freeze_layout() + V.graph.sizevars.evaluate_static_shapes(bias.get_size()) + + choices = [] + if torch._inductor.utils._use_conv_autotune_backend("ATEN"): + choices = [ + aten_convolution.bind( + args, + layout, + ordered_kwargs_for_cpp_kernel, + **kwargs, + ) + ] + + if ( + torch._inductor.utils._use_conv_autotune_backend("TRITON") + and use_triton_template(layout) + # templates only support these: + and is_ones(dilation) + and not transposed + and is_zeros(output_padding) + # there are some odd models where this check fails (e.g. 
shufflenet_v2_x1_0) + and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1]) # type: ignore[arg-type] + ): + if ( + is_ones(kernel_shape) + and is_ones(stride) + and is_zeros(padding) + and groups == 1 + ): + choices.append(aten_conv1x1_via_mm.bind(args, layout)) + + for cfg in conv_configs( + sympy_product([x.get_size()[0], *x.get_size()[2:]]), + out_chan, + in_chan, + ): + if ndim == 2: + conv2d_template.maybe_append_choice( + choices, + input_nodes=(x, weight), + layout=layout, + KERNEL_H=kernel_shape[0], + KERNEL_W=kernel_shape[1], + STRIDE_H=stride[0], + STRIDE_W=stride[1], + PADDING_H=padding[0], + PADDING_W=padding[1], + GROUPS=groups, + # TODO(jansel): try unroll for bigger kernels once fixed: + # https://github.com/openai/triton/issues/1254 + UNROLL=is_ones(kernel_shape), + ALLOW_TF32=torch.backends.cudnn.allow_tf32, + num_stages=cfg.num_stages, + num_warps=cfg.num_warps, + **cfg.kwargs, + ) + elif ndim == 3: + conv3d_template.maybe_append_choice( + choices, + input_nodes=(x, weight), + layout=layout, + KERNEL_D=kernel_shape[0], + KERNEL_H=kernel_shape[1], + KERNEL_W=kernel_shape[2], + STRIDE_D=stride[0], + STRIDE_H=stride[1], + STRIDE_W=stride[2], + PADDING_D=padding[0], + PADDING_H=padding[1], + PADDING_W=padding[2], + GROUPS=groups, + # TODO(jansel): try unroll for bigger kernels once fixed: + # https://github.com/openai/triton/issues/1254 + UNROLL=is_ones(kernel_shape), + ALLOW_TF32=torch.backends.cudnn.allow_tf32, + num_stages=cfg.num_stages, + num_warps=cfg.num_warps, + **cfg.kwargs, + ) + + return autotune_select_algorithm("convolution", choices, args, layout) + + +@register_lowering(aten._convolution) +def _convolution( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + benchmark, + deterministic, + cudnn_enabled, + allow_tf32, +): + return convolution( + x, weight, bias, stride, padding, dilation, transposed, output_padding, groups + ) + + +def constrain_conv_to_fx_strides(fx_node, *args, **kwargs): + assert fx_node.target == torch.ops.aten.convolution.default + if V.graph.layout_opt: + return args, kwargs + else: + return constrain_to_fx_strides(fx_node, *args, **kwargs) + + +add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides) diff --git a/.venv/Lib/site-packages/torch/_inductor/kernel/flex_attention.py b/.venv/Lib/site-packages/torch/_inductor/kernel/flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9001bab0b4ed66accf9d8724252a91d0aaad048b --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/kernel/flex_attention.py @@ -0,0 +1,1843 @@ +# mypy: allow-untyped-defs +""" Triton Implementation of the flex_attention Kernel""" + +import logging +import math +from typing import Any, List, Optional, Sequence, Tuple + +import sympy + +import torch +from torch._inductor.virtualized import V +from torch.utils._pytree import tree_map + +from .. 
import config +from ..ir import ( + ComputedBuffer, + ExternKernel, + FixedLayout, + FlexibleLayout, + get_stride_order, + InputBuffer, + IRNode, + StorageBox, + stride_order2fill_order, + Subgraph, + TensorBox, +) +from ..lowering import empty, empty_strided, lowerings, register_lowering +from ..select_algorithm import autotune_select_algorithm, realize_inputs, TritonTemplate + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +Expr = sympy.Expr + + +def construct_strides( + sizes: Sequence[int], + fill_order: Sequence[int], +) -> Sequence[int]: + """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" + # Initialize strides + assert len(sizes) == len( + fill_order + ), "Length of sizes must match the length of the fill order" + strides = [0] * len(sizes) + + # Start with stride 1 for the innermost dimension + current_stride = 1 + + # Iterate through the fill order populating strides + for dim in fill_order: + strides[dim] = current_stride + current_stride *= sizes[dim] + + return strides + + +def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta): + """How is this kernel parallelized? + We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1) + Each block is responsible for iterating over blocks of keys and values calculating + the final attention output. + """ + import triton + + return (triton.cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1) + + +def create_placeholder( + name: str, dtype: torch.dtype, device: torch.device +) -> TensorBox: + """Creates a placeholder input buffers for producing subgraph_output.""" + input_buffer = InputBuffer(name, FixedLayout(device, dtype, [], [])) + return TensorBox.create(input_buffer) + + +def maybe_realize(args: List[Optional[IRNode]]): + """Accepts a list of optional IRNodes and returns a list of realized IRNodes""" + return tree_map(lambda x: realize_inputs(x) if x is not None else None, args) + + +def get_float32_precision(): + if torch.get_float32_matmul_precision() == "highest" or torch.version.hip: + return "'ieee'" + else: + return "'tf32'" + + +def build_subgraph_buffer( + args: List[TensorBox], + subgraph: Subgraph, +): + """This function's goal is to take in the required args and produce the subgraph buffer + The subgraph buffer is a ComputedBuffer that will be inlined into the triton template + + Args: + args: The args that are passed into the subgraph. Contains both fixed and lifted inputs. + subgraph: The Subgraph ir for which to produce the output node + """ + cnt = 0 + env = {} + for node in subgraph.graph_module.graph.nodes: + # There are two classes of placeholder inpts that we need + # to handle differently. For the first n_scalar_inps inputs + # we expect that these placeholders were generated by the make_fx call + # in the flex Attention HOP. So we need to create a new placeholder + # TensorBox for each of these inputs. For the rest of the inputs we + # expect that these are lifted inputs that fill up the '*other_buffers' + # tuple and already have corresponding TensorBoxes passed in as args. 
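+        # Placeholders are therefore consumed positionally from `args`, in graph order.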
+ if node.op == "placeholder": + env[node] = args[cnt] + cnt += 1 + elif node.op == "call_function": + # For call_function we use the default lowerings and pass in the + # already created TensorBoxes as args + + args, kwargs = tree_map( + lambda x: env[x] if x in env else x, (node.args, node.kwargs) + ) + env[node] = lowerings[node.target](*args, **kwargs) + elif node.op == "output": + + def convert_output_node_to_buffer(output): + if output is None: + return None + output_node = output + output_buffer = env[output_node] + assert isinstance(output_buffer, TensorBox), ( + "The output node for flex attention's subgraph must be a TensorBox, but got: ", + type(output_buffer), + ) + assert isinstance(output_buffer.data, StorageBox), ( + "The output node for the flex attention subgraph must be a StorageBox, but got: ", + type(output_buffer), + ) + subgraph_buffer = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=output_buffer.data.get_device(), + dtype=output_buffer.data.get_dtype(), + size=output_buffer.data.get_size(), + ), + data=output_buffer.data.data, # type: ignore[arg-type] + ) + return subgraph_buffer + + # node.args[0] is either a single element or a list of elements + # representing all outputs of the function. + return tree_map(convert_output_node_to_buffer, node.args[0]) + + raise ValueError("FlexAttention was passed a subgraph with no output node!") + + +# Inner Triton functions shared by flex_attention & split-k decoding kernels. +compute_next_offset_func = r""" +@triton.jit +def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK): + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset +""" + +compute_flex_attention = r""" +{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. 
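+    # SPARSE_Q_BLOCK_SIZE / SPARSE_KV_BLOCK_SIZE: granularity of the block-sparse mask; each must be a multiple of BLOCK_M / BLOCK_N respectively (see the static_asserts below).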
+ + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + + Z = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0) + off_z = tl.program_id(1) // HQ + off_hq = tl.program_id(1) % HQ + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_z * stride_qz + off_hq * stride_qh + k_offset = off_z * stride_kz + off_hkv * stride_kh + v_offset = off_z * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + SPARSE_Q_BLOCK_CNT: tl.constexpr = tl.cdiv(Q_LEN, SPARSE_Q_BLOCK_SIZE) + SPARSE_KV_BLOCK_CNT: tl.constexpr = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (q_start // SPARSE_Q_MULTIPLE) * SPARSE_KV_BLOCK_CNT # noqa: B950 + + Q_block_ptr = tl.make_block_ptr( + base=Q, + shape=(Q_LEN, QK_HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(q_start * BLOCK_M, 0), + block_shape=(BLOCK_M, QK_HEAD_DIM), + order=(1, 0) + ) + + # load q: it stays in SRAM throughout the inner loop. + if IS_DIVISIBLE: + q = tl.load(Q_block_ptr) + else: + # boundary check is not free, so we only do it when necessary. 
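+        # boundary_check=(0,) zero-pads the trailing query rows of the last block when Q_LEN is not a multiple of BLOCK_M.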
+ q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option = "zero") + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(QK_HEAD_DIM, KV_LEN), + strides=(stride_kk, stride_kn), + offsets=(0, kv_start), + block_shape=(QK_HEAD_DIM, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(kv_start, 0), + block_shape=(BLOCK_N, V_HEAD_DIM), + order=(1, 0) + ) + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_z, off_hq, offs_m[:, None], offs_n[None, :], + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(QK_HEAD_DIM, KV_LEN), + strides=(stride_kk, stride_kn), + offsets=(0, kv_start), + block_shape=(QK_HEAD_DIM, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(kv_start, 0), + block_shape=(BLOCK_N, V_HEAD_DIM), + order=(1, 0) + ) + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_z, off_hq, offs_m[:, None], offs_n[None, :], + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. 
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_z = tl.program_id(1) // HQ + idx_hq = tl.program_id(1) % HQ + idx_m = offs_m[:, None] + idx_d = tl.arange(0, V_HEAD_DIM)[None, :] + + mask = idx_m < Q_LEN + # TODO generalize and add proper mask support + {{store_output(("idx_z", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} + + # TODO dont want to write this if we dont require grad + if OUTPUT_LOGSUMEXP: + off_hz = tl.program_id(1) + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + """ + + +compute_forward_inner = r""" +@triton.jit +def forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + # update pointers + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N + ) + + V_block_ptr = tl.advance(V_block_ptr, (offset, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, offset)) + + offs_n = offs_n + offset + + return acc, l_i, m_i + +""" + + +compute_forward_block_mn = r""" +@triton.jit +def forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + # -- load k -- + if IS_DIVISIBLE: + k = tl.load(K_block_ptr) + else: + k = tl.load(K_block_ptr, boundary_check=(1,), padding_option = "zero") + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. 
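+ # NB: this kernel computes softmax in base 2 (exp2/log2), since exp2 maps to a fast
+ # hardware instruction; scaling scores by RCP_LN2 = 1/ln(2) makes exp2(score * RCP_LN2)
+ # equal to exp(score). When PRESCALE_QK is set, the SM_SCALE and RCP_LN2 factors were
+ # already folded into q, so they are skipped here.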
+ if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + if CHECK_BLOCK_BOUNDARY: + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = offs_m % Q_LEN + n = offs_n % KV_LEN + else: + m = offs_m + n = offs_n + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=1, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, float("-inf")) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + # TODO: In the case that score_mod is linear, this can be LICMed + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + + if IS_DIVISIBLE: + v = tl.load(V_block_ptr) + else: + v = tl.load(V_block_ptr, boundary_check=(0,), padding_option = "zero") + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +""" + + +flex_attention_template = TritonTemplate( + name="flex_attention", + grid=flex_attention_grid, + source=compute_flex_attention + + compute_forward_inner + + compute_next_offset_func + + compute_forward_block_mn, +) + + +def _use_flex_decoding(query, kernel_options): + # Decide which kernel to use, return true if use flex decoding kernel. 
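+ # The decoding kernel is picked when the query sequence length is statically known to
+ # be short (< 128), unless the caller forces the regular flex attention kernel via
+ # kernel_options["FORCE_USE_FLEX_ATTENTION"].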
+ return ( + not kernel_options.get("FORCE_USE_FLEX_ATTENTION", False) + ) and V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 128)) + + +_h100_default_config = { + (torch.float32, 64): (128, 32, 4, 3), + (torch.float32, 128): (32, 64, 4, 3), + (torch.float32, 256): (32, 32, 4, 3), + (torch.bfloat16, 64): (128, 128, 4, 3), + (torch.bfloat16, 128): (128, 64, 8, 3), + (torch.bfloat16, 256): (64, 32, 4, 3), + (torch.float16, 64): (128, 128, 4, 3), + (torch.float16, 128): (128, 128, 8, 3), + (torch.float16, 256): (64, 32, 4, 3), +} + +_a100_default_config = { + (torch.float32, 64): (128, 32, 4, 3), + (torch.float32, 128): (128, 32, 4, 3), + (torch.float32, 256): (64, 16, 4, 3), + (torch.bfloat16, 64): (128, 64, 4, 3), + (torch.bfloat16, 128): (128, 64, 8, 3), + (torch.bfloat16, 256): (32, 64, 4, 3), + (torch.float16, 64): (128, 64, 4, 3), + (torch.float16, 128): (128, 64, 8, 3), + (torch.float16, 256): (32, 64, 4, 3), +} + + +def _get_default_config_fwd(query) -> Tuple[int, int, int, int]: + dtype = query.get_dtype() + head_dim = query.get_size()[-1] + default_config = None + + if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0): # H100 + if dtype == torch.float32: + default_config = (64, 64, 4, 3) + else: + default_config = (128, 64, 4, 3) + default_config = _h100_default_config.get((dtype, head_dim), default_config) + elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0): # A100 + if dtype == torch.float32: + default_config = (64, 64, 4, 3) + else: + default_config = (128, 64, 4, 3) + default_config = _a100_default_config.get((dtype, head_dim), default_config) + else: # modest hardware or extremely large head_dim + if dtype == torch.float32: + default_config = (32, 16, 4, 3) + else: + default_config = (64, 32, 4, 3) + + return default_config + + +def _get_default_config_bwd(query) -> Tuple[int, int, int, int]: + head_dim = query.get_size()[-1] + dtype = query.get_dtype() + + if dtype == torch.float32: + return (16, 16, 4, 1) + if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0): # H100 + if head_dim == 64: + return (64, 64, 4, 3) + elif head_dim == 128: + return (64, 128, 8, 3) + else: + return (64, 64, 4, 2) + elif torch.cuda.get_device_capability() >= (8, 0): # A100 + if head_dim == 64: + return (32, 128, 4, 3) + elif head_dim == 128: + return (64, 128, 8, 3) + else: + return (64, 64, 4, 2) + else: # modest hardware or extremely large head_dim + return (16, 16, 4, 1) + + +def create_num_blocks_fake_generator(sparse_indices): + # The idea here is that we need to create a real tensor with real data + # that's representative for benchmarking. + # For example, returning all zeros for the `kv_num_blocks` input would mean + # that we are computing 0 blocks for each row, which would provide bogus + # autotuning results. + # + # In this case, we choose to use min(16, max_block) blocks, because I + # (Horace) think it'll probably result in pretty representative performance. + # If it's too short then prefetching won't help. If it's too long then + # autotuning will take longer for no good reason. 
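+ # For example, if kv_indices has 64 blocks along its last dimension the fake tensor is
+ # filled with 16; if it only has 4 blocks it is filled with 4.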
+ def create_num_blocks_fake(x) -> torch.Tensor: + num_blocks_for_autotuning = min(16, sparse_indices.shape[-1]) + return torch.full( + x.get_size(), + int(num_blocks_for_autotuning), + dtype=x.get_dtype(), + device=x.get_device(), + ) + + return create_num_blocks_fake + + +def create_indices_fake(x) -> torch.Tensor: + indices = torch.arange( + 0, int(x.get_size()[-1]), dtype=x.get_dtype(), device=x.get_device() + ) + indices = indices.expand(x.get_size()).contiguous() + return indices + + +from torch._inductor.kernel.flex_decoding import create_flex_decoding_kernel + + +# TODO: We probably also need a layout constraint? +@register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None) +def flex_attention( + query, + key, + value, + subgraph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, +): + ( + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_KV_BLOCK_SIZE, + SPARSE_Q_BLOCK_SIZE, + mask_graph, + ) = block_mask + placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("score", query.get_dtype()), + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + subgraph_buffer = build_subgraph_buffer( + placeholder_inps + list(score_mod_other_buffers), subgraph + ) + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + mask_graph_buffer = build_subgraph_buffer( + mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph + ) + kernel_options = dict(kernel_options) + kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) + if _use_flex_decoding(query, kernel_options): + return create_flex_decoding_kernel( + query, + key, + value, + block_mask, + scale, + kernel_options, + subgraph_buffer, + mask_graph_buffer, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + ( + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + assert Bq == Bkv, "Batch dimension must match" + B = Bq + + if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: + kernel_options.setdefault("IS_DIVISIBLE", False) + else: + kernel_options.setdefault("IS_DIVISIBLE", True) + + # Reuse query strides for output layout despite different last dimension. + # This works because only the last dim differs and we check it is contiguous. + q_strides = query.get_stride() + assert q_strides[-1] == 1, "Query must be contiguous in the last dimension" + + # Construct output layout with strides matching the query. 
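+ # E.g. for a contiguous query this simply yields the contiguous strides of
+ # [B, Hq, seq_len_q, v_head_dim]; in general the output inherits the query's memory
+ # layout with only the last dimension size replaced by v_head_dim.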
+ out_size = [B, Hq, seq_len_q, v_head_dim] + stride_order = get_stride_order(query.get_stride()) + fill_order = stride_order2fill_order(stride_order) + out_strides = construct_strides(out_size, fill_order) + + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + [B, Hq, seq_len_q, v_head_dim], + stride=out_strides, + ) + # see NOTE:[TritonTemplates with multiple outputs] + logsumexp_shape = [B, Hq, seq_len_q] + logsumexp = empty_strided( + logsumexp_shape, + None, + dtype=torch.float32, # The logsumexp is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + kernel_options.setdefault("SM_SCALE", scale) + + # Determine GQA broadcast factor. + gqa_shared_heads = Hq // Hkv + kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) + + # Inside of Triton kernel, only apply partial masking if partial blocks are computed. + # full_kv_num_blocks is None if partial blocks are not computed + has_full_blocks = full_kv_num_blocks is not None + kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks) + if not has_full_blocks: + full_kv_num_blocks, full_kv_indices = ( + empty(0, device=query.get_device()) for _ in range(2) + ) + kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim) + kernel_options.setdefault("V_HEAD_DIM", v_head_dim) + + choices: List[Any] = [] + configs: List[Tuple[int, int, int, int]] = [] + configs.append(_get_default_config_fwd(query)) + if config.max_autotune: + configs += [ + (128, 64, 4, 3), + (128, 128, 4, 3), + (128, 128, 8, 2), + (64, 128, 4, 3), + (64, 64, 4, 3), + ] + + # Note, we don't need to pass in the captured buffers explicitly + # because they're implicitly added by the score_mod function + # We do need to explicitly pass it in for autotuning though. + + for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: + if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0 or SPARSE_Q_BLOCK_SIZE % BLOCK_M != 0: + continue + # Work around https://github.com/pytorch/pytorch/issues/129625 + if num_stages == 2: + continue + + # Performance tuning + kernel_options.setdefault("BLOCK_M", BLOCK_M) + kernel_options.setdefault("BLOCK_N", BLOCK_N) + # Blocksparse options + kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) + kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + + flex_attention_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ], + layout=layout, + subgraphs=[ + subgraph_buffer, + mask_graph_buffer, + ], + mutated_inputs=[ + logsumexp, + ], + num_stages=num_stages, + num_warps=num_warps, + call_sizes=query.get_size(), + **kernel_options, + ) + inputs_for_autotuning = ( + [ + query, + key, + value, + logsumexp, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ] + + list(score_mod_other_buffers) + + list(mask_mod_other_buffers) + ) + input_gen_fns = { + 4: create_num_blocks_fake_generator(kv_indices), + 5: create_indices_fake, + 6: create_num_blocks_fake_generator(full_kv_indices), + 7: create_indices_fake, + } + return ( + autotune_select_algorithm( + "flex_attention", + choices, + inputs_for_autotuning, + layout, + input_gen_fns=input_gen_fns, + ), + logsumexp, + ) + + +# ---------------------------- Backward HOP Implementation ---------------------------- + + +def flex_attention_backward_grid( + batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta +): + """How is this kernel parallelized? 
+ Currently this is only parallelizing over batch* kv_heads, but we can, and want to + parallelize over ceil_div(q_heads//kv_heads * num_key_value, key_value_block_size). + To do this will either require atomic updates to some grad values or to have a two pass kernel design. + """ + import triton + + return ( + triton.cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads) + + triton.cdiv(num_key_value, meta["BLOCK_N1"]), + 1, + batch_size * kv_heads, + ) + + +flex_attention_backward_template = TritonTemplate( + name="flex_attention_backward", + grid=flex_attention_backward_grid, + source=r""" +{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
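+ #
+ # Work decomposition along program axis 0: program ids in [0, cdiv(KV_LEN, BLOCK_N1))
+ # each own one KV tile and accumulate dK/dV for it; the remaining program ids each own
+ # one (query head, Q tile) pair and accumulate dQ (see the `pid >= NUM_KV_BLOCKS`
+ # branch below).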
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}} + stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}} + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}} + stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}} + + Z = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + HKV = {{size("K", 1)}} + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_hz = tl.program_id(2) + off_z = off_hz // HKV # batch idx + off_hkv = off_hz % HKV # kv head idx + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_z % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_z).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_z).to(tl.int64) + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_z).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM) + offs_v = tl.arange(0, V_HEAD_DIM) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_z).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_z).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_z).to(tl.int64) + off_chz2 = ((off_z * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
+ if IS_DIVISIBLE: + q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd) + do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_v[None, :] * stride_dod) + else: + q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd, mask=offs_m2[:, None] < Q_LEN) + do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_v[None, :] * stride_dod, mask=offs_m2[:, None] < Q_LEN) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_z, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_z, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=offs_m2[:, None] < Q_LEN) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}} + stride_q_idx_h = {{stride("Q_IDX", 1)}} + stride_q_idx_n = {{stride("Q_IDX", 2)}} + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + if IS_DIVISIBLE: + k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd) + v = tl.load(V + offs_n1[:, None] * stride_vn + offs_v[None, :] * stride_vd) + else: + k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd, mask=offs_n1[:, None] < KV_LEN) + v = tl.load(V + offs_n1[:, None] * stride_vn + offs_v[None, :] * stride_vd, mask=offs_n1[:, None] < KV_LEN) + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads. 
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_z).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_z).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_z).to(tl.int64) + off_chz1 = ((off_z * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_z, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_z, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + + if IS_DIVISIBLE: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=index_n < KV_LEN) + + dk *= SM_SCALE + mask = index_n < KV_LEN + {{store_output(("off_z", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} + +@triton.jit +def bwd_dq_inner( + {{gen_argdefs()}}, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM) + offs_v = tl.arange(0, V_HEAD_DIM) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + if not IS_DIVISIBLE: + if hi >= 1: + for start_n in range(0, hi - 1): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2 + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + else: + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2 + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + {{gen_defines() | indent_except_first(1)}} + + if IS_DIVISIBLE: + kT = tl.load(kT_ptrs) + else: + kT = tl.load(kT_ptrs, mask=offs_n2[None, :] < KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + if CHECK_BLOCK_BOUNDARY: + m = offs_m2[:, None] % Q_LEN + n = offs_n2[None, :] % KV_LEN + else: + m = offs_m2[:, None] + n = offs_n2[None, :] + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf")) + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. 
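+ # In matrix form (base-2 convention as above): P = exp2(S - lse), dP = dO @ V^T,
+ # dS = P * (dP - Delta), and dQ accumulates dS @ K, where Delta = rowsum(O * dO)
+ # was precomputed outside this kernel.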
+ if IS_DIVISIBLE: + vT = tl.load(vT_ptrs) + else: + vT = tl.load(vT_ptrs, mask=offs_n2[None, :] < KV_LEN) + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(1) }} + if CHECK_BLOCK_BOUNDARY: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + ds = grad_scores + + if not IS_FULL_BLOCKS: + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf")) + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + {{gen_argdefs()}}, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM) + offs_v = tl.arange(0, V_HEAD_DIM) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + if not IS_DIVISIBLE: + if hi >= 1: + for start_m in range(0, hi - 1): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1 + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + + offs_m1 += offset + + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + else: + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. 
+ offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1 + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + {{gen_defines() | indent_except_first(1) }} + + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + qT = tl.load(qT_ptrs) + lse = tl.load(LSE + offs_m1) + else: + qT = tl.load(qT_ptrs, mask=offs_m1[None, :] < Q_LEN) + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + if CHECK_BLOCK_BOUNDARY: + m = offs_m1[None, :] % Q_LEN + n = offs_n1[:, None] % KV_LEN + else: + m = offs_m1[None, :] + n = offs_n1[:, None] + pre_mod_scores = qkT + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qkT" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf")) + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + if IS_DIVISIBLE: + do = tl.load(do_ptrs) + else: + do = tl.load(do_ptrs, mask=offs_m1[:, None] < Q_LEN) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. 
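+ # Transposed counterpart of the dQ path: dP^T = V @ dO^T, dS^T = P^T * (dP^T - Delta),
+ # with dV accumulating P^T @ dO (above) and dK accumulating dS^T @ Q (below).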
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="dsT" + ) | indent_except_first(1) }} + if CHECK_BLOCK_BOUNDARY: + grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) + + dsT = grad_scores + if not IS_FULL_BLOCKS: + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf")) + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + """ + + compute_next_offset_func, +) + + +# TODO: We probably also need a layout constraint? +@register_lowering( + torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None +) +def flex_attention_backward(*args, **kwargs): + ( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) = args + ( + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_KV_BLOCK_SIZE, + SPARSE_Q_BLOCK_SIZE, + mask_graph, + ) = block_mask + + ( + query, + key, + value, + grad_out, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + grad_out, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + device = query.get_device() + dtype = query.get_dtype() + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + assert Bq == Bkv, "Batch dimension must match" + B = Bq + + kernel_options = dict(kernel_options) + kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) + if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: + kernel_options.setdefault("IS_DIVISIBLE", False) + else: + kernel_options.setdefault("IS_DIVISIBLE", True) + + fwd_placeholder_inps = [ + create_placeholder(name, dtype, device) + for name, dtype in [ + ("score", dtype), + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + fw_subgraph_buffer = build_subgraph_buffer( + fwd_placeholder_inps + list(score_mod_other_buffers), fw_graph + ) + + joint_placeholder_inps = fwd_placeholder_inps + [ + create_placeholder("grad_score_mod", dtype, device) + ] + joint_subgraph_buffer, *_ = build_subgraph_buffer( + joint_placeholder_inps + list(score_mod_other_buffers), joint_graph + ) + + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + mask_graph_buffer = build_subgraph_buffer( + mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph + ) + + layout_k = FixedLayout( + key.get_device(), + key.get_dtype(), + key.get_size(), + key.get_stride(), + ) + + # Create delta which will is needed for the bwd's kernel + grad_lse_exp2 = 
lowerings[aten.mul](grad_logsumexp, 1 / math.log(2)) + mul_delta = lowerings[aten.mul](out, grad_out) + delta = lowerings[aten.sum](mul_delta, axis=-1) + delta = lowerings[aten.sub](delta, grad_lse_exp2) + delta = ExternKernel.require_contiguous(delta) + + grad_lse_exp2, delta = maybe_realize([grad_lse_exp2, delta]) + + # see NOTE:[TritonTemplates with multiple outputs] + grad_query = empty_strided( + query.get_size(), query.get_stride(), dtype=dtype, device=device + ) + grad_value = empty_strided( + value.get_size(), value.get_stride(), dtype=dtype, device=device + ) + + kernel_options.setdefault("SM_SCALE", scale) + + # Determine GQA factor + gqa_shared_heads = Hq // Hkv + kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) + + # Inside of Triton kernel, only apply partial masking if partial blocks are computed. + # full_kv_num_blocks is torch.zeros([1, 1, 1]) if partial blocks are not computed. + has_full_blocks = full_kv_num_blocks is not None + kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks) + if not has_full_blocks: + full_kv_num_blocks, full_kv_indices, full_q_num_blocks, full_q_indices = ( + empty(0, device=query.get_device()) for _ in range(4) + ) + kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim) + kernel_options.setdefault("V_HEAD_DIM", v_head_dim) + + choices: List[Any] = [] + configs: List[Tuple[int, int, int, int]] = [] + configs.append(_get_default_config_bwd(query)) + if config.max_autotune: + configs.extend( + [ + (BLOCK1, BLOCK2, w, s) + for BLOCK1 in [32, 64] + for BLOCK2 in [32, 64, 128] + for w in [4, 8] + for s in [1, 3, 4, 5] + if BLOCK2 % BLOCK1 == 0 + ] + ) + + for BLOCK1, BLOCK2, num_warps, num_stages in configs: + if ( + SPARSE_KV_BLOCK_SIZE % BLOCK1 != 0 + or SPARSE_Q_BLOCK_SIZE % BLOCK1 != 0 + or SPARSE_KV_BLOCK_SIZE % BLOCK2 != 0 + or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0 + ): + continue + + # Performance tuning + kernel_options.setdefault("BLOCK_M1", BLOCK1) + kernel_options.setdefault("BLOCK_N1", BLOCK2) + kernel_options.setdefault("BLOCK_M2", BLOCK2) + kernel_options.setdefault("BLOCK_N2", BLOCK1) + # Blocksparse options + kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) + kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + + flex_attention_backward_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + delta, + grad_out, + grad_query, + grad_value, + kv_num_blocks, + kv_indices, + q_num_blocks, + q_indices, + full_kv_num_blocks, + full_kv_indices, + full_q_num_blocks, + full_q_indices, + ], + layout=layout_k, # We use store_output only for grad_key + subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer, mask_graph_buffer], + mutated_inputs=[grad_query, grad_value], + call_sizes=query.get_size() + key.get_size()[1:3], + num_stages=num_stages, + num_warps=num_warps, + **kernel_options, + ) + inputs_for_autotuning = ( + [ + query, + key, + value, + logsumexp, + delta, + grad_out, + grad_query, + grad_value, + kv_num_blocks, + kv_indices, + q_num_blocks, + q_indices, + full_kv_num_blocks, + full_kv_indices, + full_q_num_blocks, + full_q_indices, + ] + + list(score_mod_other_buffers) + + list(mask_mod_other_buffers) + ) + input_gen_fns = { + 8: create_num_blocks_fake_generator(kv_indices), # kv_num_blocks + 9: create_indices_fake, + 10: create_num_blocks_fake_generator(q_indices), # q_num_blocks + 11: create_indices_fake, + 12: create_num_blocks_fake_generator(full_kv_indices), # full_kv_num_blocks + 13: create_indices_fake, + 14: 
create_num_blocks_fake_generator(full_q_indices), # full_q_num_blocks + 15: create_indices_fake, + } + + grad_key = autotune_select_algorithm( + "flex_attention_backward", + choices, + inputs_for_autotuning, + layout_k, + input_gen_fns=input_gen_fns, + ) + return ( + grad_query, + grad_key, + grad_value, + ) diff --git a/.venv/Lib/site-packages/torch/_inductor/kernel/flex_decoding.py b/.venv/Lib/site-packages/torch/_inductor/kernel/flex_decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..6e1ed5dbee515716648402360cacb01c36cd79a0 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/kernel/flex_decoding.py @@ -0,0 +1,570 @@ +# mypy: allow-untyped-defs +""" Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)""" +from typing import Any, List, Tuple + +import sympy + +import torch +from torch._inductor.virtualized import V + +from .. import config, ir +from ..ir import FixedLayout, FlexibleLayout +from ..lowering import empty, empty_strided, lowerings +from ..runtime.runtime_utils import is_power_of_2, next_power_of_2 +from ..select_algorithm import autotune_select_algorithm, TritonTemplate +from .flex_attention import ( + compute_forward_block_mn, + compute_forward_inner, + compute_next_offset_func, + create_indices_fake, + create_num_blocks_fake_generator, + maybe_realize, +) + + +aten = torch.ops.aten +prims = torch.ops.prims + + +def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta): + """How is this kernel parallelized? + We create a grid of (batch_size * kv_heads, SPLIT_KV, 1) + Each block is responsible for iterating over blocks of keys and values calculating + the local output for their tile of keys and values over all full length of query. + groups of SPLIT_KV blocks then combine their output to produce the final result. + """ + + return (batch_size * kv_heads, meta["SPLIT_KV"], 1) + + +flex_decoding_template = TritonTemplate( + name="flex_decoding", + grid=flex_decoding_grid, + source=r""" + {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. 
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}} + stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}} + + + Z = {{size("Q", 0)}} + HKV = {{size("Q", 1)}} + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = {{size("Q", 3)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0) // HKV + off_hkv = tl.program_id(0) % HKV + off_t = tl.program_id(1) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_z * stride_kz + off_hkv * stride_kh + v_offset = off_z * stride_vz + off_hkv * stride_vh + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_z % SPARSE_Z + # TODO: support masks not broadcasted along the head dimension. + tl.device_assert(SPARSE_HQ == 1) + sparse_idx_h = 0 + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM) + offs_vd = tl.arange(0, V_HEAD_DIM) + + # KV_IDX / FULL_KV_IDX and KV_NUM_BLKS / FULL_KV_NUM_BLKS are always contiguous. + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_h + + # Calculate KV blocks that belong this CTA. 
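+ # Illustrative example (values are examples only): KV_LEN=4096, SPLIT_KV=8 and
+ # BLOCK_N=64 give TILE_KV=512 and TILE_KV_MULTIPLE=8, so split t handles loop
+ # iterations [8*t, 8*t + 8) over the (block-sparse-selected) KV blocks.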
+ block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if SAFE_M_BOUNDARY: + q = tl.load(Q + q_offset + q_range) + else: + mask = off_m[None, :, None] < Q_LEN + q = tl.load(Q + q_offset + q_range, mask) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Apply both score_mod and mask_mod + + # find first kv block we are loading and the number of blocks we are loading + kv_indices = KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_hz_offset) + indices_idx = block_n_start // SPARSE_KV_MULTIPLE + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(off_n, 0), + block_shape=(BLOCK_N, V_HEAD_DIM), + order=(1, 0) + ) + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_hz_offset) + indices_idx = block_n_start // SPARSE_KV_MULTIPLE + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(off_n, 0), + block_shape=(BLOCK_N, V_HEAD_DIM), + order=(1, 0) + ) + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + 
MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} + """ + + compute_forward_inner + + compute_next_offset_func + + compute_forward_block_mn, +) + + +def get_split_k(B: int, H: int, Mk: int, SM: int = 128) -> int: + """Heuristic for the number of splits from xformer""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = SM // bh # Each SM should at least get one block. + split_k = max(split_k, 1) + + return split_k + + +def _get_decoding_default_config(key) -> Tuple[int, int, int]: + dtype = key.get_dtype() + head_dim = key.get_size()[-1] + sm_version = torch.cuda.get_device_capability() + default_config = (64, 2, 1) + if sm_version >= (9, 0): + if head_dim > 128 and dtype == torch.float32: + return default_config + return (64, 2, 3) + return default_config + + +def create_flex_decoding_kernel(*args, **kwargs): + ( + query, + key, + value, + block_mask, + scale, + kernel_options, + score_mod_subgraph, + mask_mod_subgraph, + score_mod_other_buffers, + mask_mod_other_buffers, + ) = args + ( + kv_num_blocks, + kv_indices, + full_kv_num_blocks, # full_kv_num_blocks, + full_kv_indices, # full_kv_indices, + _, # q_num_blocks + _, # q_indices + _, # full_q_num_blocks, + _, # full_q_indices, + SPARSE_KV_BLOCK_SIZE, + _, # SPARSE_Q_BLOCK_SIZE, + _, + ) = block_mask + + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + assert Bq == Bkv, "Batch dimension must match" + B = Bq + kernel_options = dict(kernel_options) + + # TODO: Fix flex decoding non-divisible case! + if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: + kernel_options.setdefault("IS_DIVISIBLE", False) + else: + kernel_options.setdefault("IS_DIVISIBLE", True) + + # Calculate GQA head sharing + gqa_shared_heads = Hq // Hkv + if not is_power_of_2(gqa_shared_heads): + raise ValueError( + "Number of shared query heads sharing the same KV head must be power of 2. 
" + ) + kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) + + # Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod + has_full_blocks = full_kv_num_blocks is not None + kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks) + if not has_full_blocks: + # Create a plackeholder full block list in case it is empty + full_kv_num_blocks, full_kv_indices = ( + empty(0, device=query.get_device()) for _ in range(2) + ) + + ( + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ) = maybe_realize( + [ + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ] + ) + + choices: List[Any] = [] + configs: List[Tuple[int, int, int]] = [] + configs.append(_get_decoding_default_config(key)) + # Note: max_autotune is not supported yet. Causes error in lowering the dynamic shape in reduction ops. + if config.max_autotune: + configs += [ + (64, 2, 2), + (32, 2, 3), + (128, 2, 3), + ] + # TODO: fix autotuning. + + kernel_options.setdefault("SM_SCALE", scale) + kernel_options.setdefault("SPLIT_KV", get_split_k(B, Hkv, seq_len_kv)) + MAX_SPLIT_KV = kernel_options["SPLIT_KV"] + + # create config dependent intermediate buffers + buf_ACC_shape = [B, MAX_SPLIT_KV, Hq, seq_len_q, v_head_dim] + buf_ML_shape = buf_ACC_shape[:-1] + buf_M = empty_strided( + buf_ML_shape, + None, + dtype=torch.float32, # The rowmax is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + buf_L = empty_strided( + buf_ML_shape, + None, + dtype=torch.float32, # The intermediate sumexp is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + + layout_acc = FixedLayout( + query.get_device(), + torch.float32, + buf_ACC_shape, + FlexibleLayout.contiguous_strides(buf_ACC_shape), + ) + + kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim) + kernel_options.setdefault("V_HEAD_DIM", v_head_dim) + + kernel_options.setdefault( + "BLOCK_M", + ( + # m + # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0)) + # else # Always use a BLOCK_M > 16 before Triton fix https://github.com/triton-lang/triton/pull/4061 is in pin + max( + next_power_of_2( + V.graph.sizevars.size_hint( + seq_len_q, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] + ) + * gqa_shared_heads + ), + 16, + ) + ), + ) + + query = ir.ExternKernel.realize_input(query) + stride_b, stride_hq, stride_seq_len_q, stride_qk_head_dim = query.get_stride() + + # Reshape query for GQA: [B, Hq, Mq, D] -> [B, Hkv, G, Mq, D] + gqa_query_shape = (B, Hkv, gqa_shared_heads, seq_len_q, qk_head_dim) + gqa_query_stride = ( + stride_b, + stride_hq * gqa_shared_heads, + stride_hq, + stride_seq_len_q, + stride_qk_head_dim, + ) + query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride) + + V.graph.sizevars.guard_leq( + seq_len_q * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"]) + ) + + kernel_options.setdefault( + "SAFE_M_BOUNDARY", + ((seq_len_q * gqa_shared_heads) % kernel_options["BLOCK_M"]) == 0, + ) + # TODO: This feels sketchy + kernel_options.setdefault("SAFE_N_BOUNDARY", True) + + # Note, we don't need to pass in the captured buffers explicitly + # because they're implicitly added by the score_mod function + # We do need to explicitly pass it in for autotuning though. 
+ for BLOCK_N, num_warps, num_stages in configs: + if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0: + continue + + # Performance tuning + kernel_options.setdefault("BLOCK_N", BLOCK_N) + kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + + # Work around https://github.com/pytorch/pytorch/issues/129625 + if num_stages == 2: + continue + flex_decoding_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + buf_M, + buf_L, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ], + layout=layout_acc, + subgraphs=[ + score_mod_subgraph, + mask_mod_subgraph, + ], + mutated_inputs=[buf_M, buf_L], + num_stages=num_stages, + num_warps=num_warps, + call_sizes=query.get_size(), + **kernel_options, + ) + + inputs_for_flex_decoding = ( + [ + query, + key, + value, + buf_M, + buf_L, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ] + + list(score_mod_other_buffers) + + list(mask_mod_other_buffers) + ) + + input_gen_fns = { + 5: create_num_blocks_fake_generator(kv_indices), + 6: create_indices_fake, + 7: create_num_blocks_fake_generator(full_kv_indices), + 8: create_indices_fake, + } + + buf_ACC = autotune_select_algorithm( + "flex_decoding", + choices, + inputs_for_flex_decoding, + layout_acc, + input_gen_fns=input_gen_fns, + ) + + # Reduction + + g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0] + # See [Note] Handle fully masked out rows: + # g_M Is the global max among split kv blocks. + masked_rows = lowerings[aten.eq](g_M, -float("inf")) + adj_M = lowerings[aten.sub](buf_M, g_M) + adj_M = lowerings[aten.where](masked_rows, 0, adj_M) + alpha = lowerings[aten.exp2](adj_M) + + buf_L = lowerings[aten.mul](buf_L, alpha) + g_L = lowerings[aten.sum](buf_L, axis=1) + masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1) + g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L) + logsumexp = lowerings[aten.log2](g_L) + logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1)) + + alpha_unseq = lowerings[aten.unsqueeze](alpha, 4) + buf_ACC = lowerings[aten.mul](buf_ACC, alpha_unseq) + output = lowerings[aten.sum](buf_ACC, axis=1) + L_unseq = lowerings[aten.unsqueeze](g_L, 3) + output = lowerings[aten.div](output, L_unseq) + output = lowerings[prims.convert_element_type](output, query.get_dtype()) + + return ( + output, + logsumexp, + ) diff --git a/.venv/Lib/site-packages/torch/_inductor/kernel/mm.py b/.venv/Lib/site-packages/torch/_inductor/kernel/mm.py new file mode 100644 index 0000000000000000000000000000000000000000..5741898bd06979ca9979d32218ee88848e201585 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/kernel/mm.py @@ -0,0 +1,776 @@ +# mypy: allow-untyped-defs +import functools +import logging +from typing import Any, Dict, List, Optional + +import torch +from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + context_add_strides, + context_add_using_tf32, + get_mixedmm_precondition, + mixed_mm_operations, + mm_operations, +) +from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate +from torch._inductor.virtualized import V + +from .. 
import config as inductor_config +from ..codegen.common import BackendFeature +from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate +from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate +from ..codegen.wrapper import WrapperCodeGen +from ..ir import FlexibleLayout, is_triton +from ..lowering import register_lowering +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + NoValidChoicesError, + TritonTemplate, +) +from ..utils import ( + get_gpu_shared_memory, + use_aten_gemm_kernels, + use_ck_template, + use_cpp_packed_gemm_template, + use_cutlass_template, + use_max_autotune, + use_triton_template, +) +from .mm_common import ( + addmm_epilogue, + extra_mm_configs, + int8_mm_configs, + mixed_mm_configs, + mm_args, + mm_configs, + mm_grid, + mm_options, + triton_config, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + +mm_template = TritonTemplate( + name="mm", + grid=mm_grid, + source=r""" +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + ram = rm % M + if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1): + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + rbn = rn % N + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) 
+ if B_PROLOGUE_CAST_TYPE is not None: + b = b.to(B_PROLOGUE_CAST_TYPE) + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""", +) + +aten_mm = ExternKernelChoice(torch.mm, "at::mm_out") + +aten_addmm = ExternKernelChoice( + torch.addmm, "at::addmm_out", op_overload=aten.addmm.default +) + +aten__int_mm = ExternKernelChoice(torch._int_mm, "at::_int_mm") + +aten__sparse_semi_structured_mm = ExternKernelChoice( + torch._sparse_semi_structured_mm, + "at::_sparse_semi_structured_mm", + has_out_variant=False, +) + + +def _is_int8_mat(mat): + return mat.get_dtype() in (torch.int8, torch.uint8) + + +def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): + """ + Giving torch.addmm a 1D tensor calls a different (faster) cublasLt + kernel under the hood. There are a few shapes where this is slower, + but they are rare. + """ + if inp.stride(0) == 0 or inp.size(0) == 1: + return torch.addmm(inp[0], mat1, mat2, out=out, alpha=alpha, beta=beta) + return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta) + + +aten_bias_addmm = ExternKernelChoice(bias_addmm, None) + + +@register_lowering(aten.mm, type_promotion_kind=None) +def tuned_mm(mat1, mat2, *, layout=None): + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + name = "mm" + + aten_layout = layout + if not use_max_autotune(): + aten_layout = FlexibleLayout( + device=layout.device, dtype=layout.dtype, size=layout.size + ) + + # options to tune from + choices = ( + [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else [] + ) + static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout) + if is_nonzero and use_triton_template(layout): + for config in mm_configs(m, n, k): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) + + if is_nonzero and use_ck_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2]) + + if use_cpp_packed_gemm_template(layout, mat1, mat2): + CppPackedGemmTemplate.add_choices( + choices, + layout, + [mat1, mat2], + ) + + input_nodes = [mat1, mat2] + if ( + is_nonzero + and use_triton_template(layout) + and torch._inductor.config.run_autoheuristic(name) + and is_triton(mat1) + ): + always_included = [] + if use_aten_gemm_kernels(): + always_included.append("extern_mm") + num_choices_before_extra_configs = len(choices) + for config in extra_mm_configs(m, n, k): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + # using AutoHeuristic for ranking + ah_choices = mm_autoheuristic( + mat1, + mat2, + m, + n, + k, + choices, + name, + input_nodes, + mm_operations(), + None, + top_k=10, + always_included=always_included, + ) + if not torch._inductor.config.collect_autoheuristic(name): + # if we are collecting data, we do not want to modify choices + if ah_choices is not None and len(ah_choices) > 0: + # the order in which autoheuristic returns choices is not the same 
as + # as the order of choices, which affects things like epilogue fusion. + # once epilogue fusion benchmarks choices in sorted order, I think we can + # just use the order returned by autoheuristic + choices = [choice for choice in choices if choice in ah_choices] + else: + choices = choices[:num_choices_before_extra_configs] + + if ( + len(choices) == 0 + and not use_aten_gemm_kernels() + and inductor_config.autotune_fallback_to_aten + ): + log.warning("No choices for GEMM, using ATen backend as fallback") + return aten_mm.bind((mat1, mat2), aten_layout).output_node() + + try: + return autotune_select_algorithm(name, choices, [mat1, mat2], layout) + except NoValidChoicesError: + if not inductor_config.autotune_fallback_to_aten: + raise + log.warning("All choices for GEMM were invalid, using ATen backend as fallback") + return aten_mm.bind((mat1, mat2), aten_layout).output_node() + + +def _is_static_problem(inputs_tensors, layout): + # checks whether all input tensors and the output layout + # have a static shape by attempting to convert the dimensions + # to int + static_shape = True + static_size = WrapperCodeGen.statically_known_list_of_ints_or_none(layout.size) + if static_size is None: + nonzero = True + for s in layout.size: + sz = WrapperCodeGen.statically_known_int_or_none(s) + if sz is not None and sz == 0: + nonzero = False + break + return False, nonzero + numel = 1 + for dim in static_size: + numel *= dim + nonzero = numel > 0 + return static_shape, nonzero + + +@register_lowering(aten._int_mm, type_promotion_kind=None) +def tuned_int_mm(mat1, mat2, *, layout=None): + m, n, k, layout, mat1, mat2 = mm_args( + mat1, mat2, layout=layout, out_dtype=torch.int32 + ) + static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout) + use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k) + + choices = ( + [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] + ) + + # TODO: Re-enable eager mode implementation once cuBLAS is fixed + if use_cutlass or use_triton_template(layout, enable_int32=True): + choices = [] + + if use_cutlass: + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True + ) + if is_nonzero and use_triton_template(layout, enable_int32=True): + for config in int8_mm_configs(m, n, k): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + if len(choices) == 0: + log.warning( + "No choices for integer GEMM avaialbe using configured backends, using ATen backend as fallback" + ) + choices = [aten__int_mm.bind((mat1, mat2), layout)] + + try: + return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout) + except NoValidChoicesError: + if not inductor_config.autotune_fallback_to_aten: + raise + log.warning("All choices for GEMM were invalid, using ATen backend as fallback") + choices = [aten__int_mm.bind((mat1, mat2), layout)] + return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout) + + +@register_lowering(aten.addmm, type_promotion_kind=None) +def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): + ordered_kwargs_for_cpp_kernel = ("beta", "alpha") + m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout) + static_shape, is_nonzero = _is_static_problem([inp, mat1, mat2], layout) + if (not is_nonzero) or (not use_max_autotune()): + # Use a FlexibleLayout if we are not autotuning. 
+ # This allows padding strides for the output. + from torch._inductor.ir import FixedLayout, FlexibleLayout + + if isinstance(layout, FixedLayout): + layout = FlexibleLayout( + device=layout.device, dtype=layout.dtype, size=layout.size + ) + choices = ( + [ + aten_addmm.bind( + (inp, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ] + if use_aten_gemm_kernels() + else [] + ) + return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout) + + choices = ( + [ + aten_addmm.bind( + (inp_expanded, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ] + if use_aten_gemm_kernels() + else [] + ) + + if ( + use_aten_gemm_kernels() + and inp_expanded.get_stride()[0] == 0 + and inp_expanded.get_device().type == "cuda" + and inductor_config.triton.autotune_cublasLt + ): + # unexpand inp to make sure fused addmm from cublasLt is used + choices.insert( + 0, + aten_bias_addmm.bind( + (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta + ), + ) + + if is_nonzero and use_triton_template(layout): + for config in mm_configs(m, n, k): + mm_template.maybe_append_choice( + choices, + input_nodes=(inp_expanded, mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + ) + + if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k): + # Filter out a known cause of CUDA illegal memory access errors + # broadcasting on the last dim of the bias term seems not to be working + # in the linear GEMM epilogue used by addmm. + if ( + WrapperCodeGen.statically_known_int_or_none(inp_expanded.layout.stride[-1]) + != 0 + ): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, + layout, + [mat1, mat2, inp_expanded], + alpha=alpha, + beta=beta, + ) + + if is_nonzero and use_ck_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices( + choices, + layout, + [mat1, mat2, inp_expanded], + alpha=alpha, + beta=beta, + ) + + if use_cpp_packed_gemm_template(layout, mat1, mat2): + CppPackedGemmTemplate.add_choices( + choices, + layout, + [inp_expanded, mat1, mat2], + alpha=alpha, + beta=beta, + has_bias=True, + ) + + add_aten_fallback = False + if len(choices) == 0: + log.warning("No choices for GEMM, using ATen backend as fallback") + add_aten_fallback = True + + if add_aten_fallback: + choices.append( + aten_addmm.bind( + (inp_expanded, mat1, mat2), + layout, + ordered_kwargs_for_cpp_kernel, + alpha=alpha, + beta=beta, + ) + ) + + if ( + inp_expanded.get_stride()[0] == 0 + and inp_expanded.get_device().type == "cuda" + and inductor_config.triton.autotune_cublasLt + ): + # unexpand inp to make sure fused addmm from cublasLt is used + choices.insert( + 0, + aten_bias_addmm.bind( + (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta + ), + ) + try: + return autotune_select_algorithm( + "addmm", choices, [inp_expanded, mat1, mat2], layout + ) + except NoValidChoicesError: + if not inductor_config.autotune_fallback_to_aten: + raise + log.warning("All choices for GEMM were invalid, using ATen backend as fallback") + fallback_choice = aten_addmm.bind( + (inp, mat1, mat2), + layout, + ordered_kwargs_for_cpp_kernel, + alpha=alpha, + beta=beta, + ) + return fallback_choice.output_node() + + +@register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None) +def tuned_sparse_semi_structured_mm( + mat1, mat1_meta, mat2, *, out_dtype=None, layout=None +): + from torch._inductor.select_algorithm import realize_inputs + + mat1, mat1_meta, mat2 = realize_inputs(mat1, 
mat1_meta, mat2) + m1, k1 = mat1.get_size() + m2, _ = mat1_meta.get_size() + k2, n = mat2.get_size() + m = V.graph.sizevars.guard_equals(m1, m2) + k = V.graph.sizevars.guard_equals(2 * k1, k2) + + if layout is None: + from torch._inductor.ir import FixedLayout + + layout = FixedLayout( + mat2.get_device(), + out_dtype if out_dtype else mat2.get_dtype(), + [m, n], + [n, 1], + ) + else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." + + choices = ( + [ + aten__sparse_semi_structured_mm.bind( + (mat1, mat1_meta, mat2), layout, out_dtype=out_dtype + ) + ] + if use_aten_gemm_kernels() + else [] + ) + + if m * n != 0 and use_cutlass_template(layout, m, n, k): + CUTLASS2xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, [mat1, mat2, mat1_meta], fuseable=True, non_fuseable=True + ) + + return autotune_select_algorithm( + "sparse_semi_structured_mm", choices, [mat1, mat1_meta, mat2], layout + ) + + +def fallback_mixed_mm(mat1, mat2, *, out): + return torch.mm(mat1, mat2.to(mat1.dtype), out=out) + + +aten_fallback_mixed_mm = ExternKernelChoice(fallback_mixed_mm, None) + + +@functools.lru_cache(None) +def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool: + props = torch.cuda.get_device_properties(index or 0) + return props.major <= 7 + + +def dims_are_int(dims): + return all(isinstance(dim, int) for dim in dims) + + +def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout): + m, n, k = get_size_hints(mat1, mat2, m, n, k) + if not dims_are_int([m, n, k]): + return None + + if mat1.dtype != torch.float16: + return None + + # only use heuristic if we are running on an A100 + # torch.cuda.get_device_capability() >= (8, 0) returns true for A10G + # which does not have enough shared memory for one of the configs + if ( + not torch.cuda.get_device_capability() >= (8, 0) + ) or get_gpu_shared_memory() != 166912: + return None + + if m == 1 and (n % 16 != 0 or k % 16 != 0): + return None + + if m <= 16 and n >= 4096 and k >= 4096: + return triton_config( + BLOCK_M=16, + BLOCK_N=64, + BLOCK_K=128, + num_stages=5, + num_warps=4, + ) + elif m > 16 and m <= 32 and n >= 4096 and k >= 4096: + return triton_config( + BLOCK_M=32, + BLOCK_N=32, + BLOCK_K=128, + num_stages=5, + num_warps=4, + ) + elif m > 32 and m <= 64 and n >= 4096 and k >= 4096: + return triton_config( + BLOCK_M=64, + BLOCK_N=32, + BLOCK_K=128, + num_stages=5, + num_warps=4, + ) + return None + + +def mm_autoheuristic( + mat1, + mat2, + m, + n, + k, + choices, + name, + input_nodes, + ops, + precondition, + top_k: Optional[int] = None, + always_included=None, +): + m, n, k = get_size_hints(mat1, mat2, m, n, k) + if not dims_are_int([m, n, k]): + return None + mat1_stride, mat2_stride = get_size_hints_strides(mat1, mat2) + + def get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride): + context = AHContext() + context.add_feature("m", m) + context.add_feature("k", k) + context.add_feature("n", n) + context.add_feature("mat1_dtype", mat1.layout.dtype, is_categorical=True) + context.add_feature("mat2_dtype", mat2.layout.dtype, is_categorical=True) + context_add_strides(context, "mat1", mat1_stride) + context_add_strides(context, "mat2", mat2_stride) + context.add_feature( + "mat1_iscontig", mat1.layout.is_contiguous(), is_categorical=True + ) + context.add_feature( + "mat2_iscontig", mat2.layout.is_contiguous(), is_categorical=True + ) + if name == "mm": + # for mixed_mm, we only consider fp16 + context_add_using_tf32(context, mat1.layout.dtype) + return context + + def fallback(): + return None + + 
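For reference, a standalone sketch (made-up shapes) of only the (m, n, k) bucketing that try_heuristic above uses to pick BLOCK_M; the real function also gates on fp16 inputs, A100 shared-memory size, and m==1 alignment, and returns a full triton_config:

def pick_block_m(m, n, k):
    # mirrors the size buckets in try_heuristic for large n and k
    if m <= 16 and n >= 4096 and k >= 4096:
        return 16
    if 16 < m <= 32 and n >= 4096 and k >= 4096:
        return 32
    if 32 < m <= 64 and n >= 4096 and k >= 4096:
        return 64
    return None  # fall back to regular autotuning

assert pick_block_m(16, 8192, 8192) == 16
assert pick_block_m(48, 8192, 8192) == 64
assert pick_block_m(128, 8192, 8192) is None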
context = get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride) + autoheuristic = AutoHeuristicSelectAlgorithm( + fallback=fallback, + choices=choices, + input_nodes=input_nodes, + context=context, + name=name, + augment_context=ops, + precondition=precondition, + ) + + if top_k is not None: + # TODO: is there a cleaner way to ensure aten.mm is always included? + return autoheuristic.get_top_k_choices_caller( + top_k, always_included=always_included + ) + + return autoheuristic.get_choice_caller() + + +def get_size_hints(mat1, mat2, m, n, k): + if not isinstance(m, int) or not isinstance(k, int): + (m, k) = V.graph.sizevars.size_hints( + mat1.get_size(), + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + + if not isinstance(n, int) or not isinstance(k, int): + (k, n) = V.graph.sizevars.size_hints( + mat2.get_size(), + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + return m, n, k + + +def get_size_hints_strides(mat1, mat2): + mat1_stride = mat1.layout.stride + mat2_stride = mat2.layout.stride + strides = [mat1_stride, mat2_stride] + strides_hints = [] + for stride in strides: + if not isinstance(stride, int): + stride = V.graph.sizevars.size_hints( + stride, + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + strides_hints.append(stride) + return strides_hints[0], strides_hints[1] + + +def tuned_mixed_mm(mat1, mat2, mat2_dtype): + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None) + static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout) + + fallback = aten_fallback_mixed_mm.bind((mat1, mat2), layout) + + choices = [fallback] + + # can't use triton kernel unless one of these is true or if running on v100 (numerical issues) + skip_triton = ( + ( + mat1.layout.dtype != torch.float32 + and not (mat2.layout.is_contiguous() or mat2.layout.is_transposed()) + ) + or _is_sm7x_or_older_gpu(layout.device.index) + or inductor_config.mixed_mm_choice == "aten" + or not V.graph.has_feature(layout.device, BackendFeature.TRITON_TEMPLATES) + or ( + mat1.layout.dtype == torch.float32 and torch.backends.cuda.matmul.allow_tf32 + ) + or (mat1.layout.dtype == torch.bfloat16 and mat2.layout.dtype == torch.uint8) + ) + + if inductor_config.mixed_mm_choice == "triton": + choices = [] + + if not skip_triton: + b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "") + if static_shape and inductor_config.mixed_mm_choice == "heuristic": + choices = [] + config = try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout) + if config is not None: + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout, b_prologue_cast_type), + ) + choices.append(fallback) + + has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2) + for config in mixed_mm_configs(m, n, k, has_int8_tensor=has_int8_tensor): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout, b_prologue_cast_type), + ) + + if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True + ) + CUTLASS2xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True + ) + + if skip_triton and not choices: + choices = [fallback] + + name = "mixed_mm" + input_nodes = [mat1, mat2] + if torch._inductor.config.run_autoheuristic(name): + choice = mm_autoheuristic( + mat1, + mat2, 
+ m, + n, + k, + choices, + name, + input_nodes, + mixed_mm_operations(), + get_mixedmm_precondition, + ) + if ( + not skip_triton + and inductor_config.mixed_mm_choice == "heuristic" + and choice is not None + ): + choices.insert(0, choice) + return autotune_select_algorithm(name, choices, input_nodes, layout) + + +# This op is a special case of the int_mm op which we use based on the pattern +# _int_mm -> mul (defined in ../fx_passes/post_grad.py) in order to prevent +# realization of the int32 _int_mm output by forcing fusion with the mul op. +# This is only used when config.force_fuse_int_mm_with_mul = True +def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None): + out_dtype = ( + torch.promote_types(mat3.get_dtype(), torch.int32) + if out_dtype is None + else out_dtype + ) + m, n, k, layout, mat1, mat2, mat3 = mm_args( + mat1, mat2, mat3, layout=layout, out_dtype=out_dtype + ) + choices: List[Dict[Any, Any]] = [] + for config in int8_mm_configs(m, n, k): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2, mat3), + layout=layout, + **dict(mm_options(config, m, n, k, layout), ACC_TYPE="tl.int32"), + suffix_args=1, + epilogue_fn=V.ops.mul, + ) + return autotune_select_algorithm("int_mm", choices, [mat1, mat2, mat3], layout) diff --git a/.venv/Lib/site-packages/torch/_inductor/kernel/mm_common.py b/.venv/Lib/site-packages/torch/_inductor/kernel/mm_common.py new file mode 100644 index 0000000000000000000000000000000000000000..088f80ff3195274e19d4eed9cb5b767a06fa98d4 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/kernel/mm_common.py @@ -0,0 +1,466 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +from typing import cast, List, Tuple + +import sympy + +import torch +from torch._inductor.select_algorithm import realize_inputs +from torch._inductor.virtualized import V + +from .. 
import config as inductor_config +from ..runtime.runtime_utils import next_power_of_2 +from ..utils import ceildiv as cdiv + + +log = logging.getLogger(__name__) + + +def triton_config(num_stages, num_warps, **kwargs): + from triton import Config + + return Config(kwargs, num_stages=num_stages, num_warps=num_warps) + + +def filtered_configs( + m: int, + n: int, + k: int, + configs: List[Tuple[int, int, int, int, int]], + has_int8_tensor=False, +): + """Heuristic to shrink configs when they are bigger than the input size""" + + min_block_size = 16 + # block_k=16 seems to be causing issues + # see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424 + min_block_size_k = 32 if has_int8_tensor else 16 + m = max( + next_power_of_2( + V.graph.sizevars.size_hint( + m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] + ) + ), + min_block_size, + ) + n = max( + next_power_of_2( + V.graph.sizevars.size_hint( + n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] + ) + ), + min_block_size, + ) + k = max( + next_power_of_2( + V.graph.sizevars.size_hint( + k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] + ) + ), + min_block_size_k, + ) + used = set() + for block_m, block_n, block_k, num_stages, num_warps in configs: + # shrink configs for small sizes + block_m = max(min(block_m, m), min_block_size) + block_n = max(min(block_n, n), min_block_size) + block_k = max(min(block_k, k), min_block_size_k) + # each warp computes 16x16 tile = 256 + num_warps = min(num_warps, block_m * block_n // 256) + if torch.version.hip: + for matrix_instr_nonkdim in [0, 16]: + if matrix_instr_nonkdim != 0 and ( + block_m % matrix_instr_nonkdim != 0 + or block_n % matrix_instr_nonkdim != 0 + ): + # block_m and block_n must be a multiple of matrix_instr_nonkdim + continue + if ( + block_m, + block_n, + block_k, + num_stages, + num_warps, + matrix_instr_nonkdim, + ) not in used: + used.add( + ( + block_m, + block_n, + block_k, + num_stages, + num_warps, + matrix_instr_nonkdim, + ) + ) + yield triton_config( + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_stages=num_stages, + num_warps=num_warps, + matrix_instr_nonkdim=matrix_instr_nonkdim, + ) + else: + if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used: + used.add((block_m, block_n, block_k, num_stages, num_warps, 0)) + yield triton_config( + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_stages=num_stages, + num_warps=num_warps, + ) + + +# List of dictionaries to store the kernel configs. Configs that evaluate to true +# will be utilised on the target platform. 
The configs are as follows: +# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) +mm_kernel_configs = ( + [ + {"config": (32, 32, 16, 1, 2), "cond": True}, + {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None}, + {"config": (32, 64, 32, 5, 8), "cond": True}, + {"config": (64, 32, 32, 5, 8), "cond": True}, + {"config": (64, 32, 128, 5, 4), "cond": True}, + {"config": (64, 64, 16, 2, 4), "cond": True}, + {"config": (64, 64, 32, 2, 4), "cond": True}, + {"config": (64, 64, 64, 3, 8), "cond": True}, + {"config": (64, 64, 128, 5, 4), "cond": True}, + {"config": (64, 128, 32, 3, 4), "cond": True}, + {"config": (64, 128, 32, 4, 8), "cond": True}, + {"config": (64, 128, 64, 3, 4), "cond": True}, + {"config": (64, 128, 128, 4, 4), "cond": True}, + {"config": (128, 64, 32, 3, 4), "cond": True}, + {"config": (128, 64, 32, 4, 8), "cond": True}, + {"config": (128, 128, 32, 2, 8), "cond": True}, + {"config": (128, 128, 32, 3, 4), "cond": True}, + {"config": (128, 128, 64, 3, 4), "cond": True}, + {"config": (128, 128, 64, 5, 8), "cond": True}, + ] + if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE" + else [ + {"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True} + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, 2, 3, 4, 5] + for num_warps in [2, 4, 8] + ] +) + +# these are only used in tuned_mm when AutoHeuristic is enabled +# the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned +# when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10 +# which saves compilation time (since less configs are autotuned) and potentially increase performance +# because the learned heuristic might predict a config that is not part mm_configs +extra_mm_kernel_configs = [ + {"config": (16, 32, 16, 3, 2), "cond": True}, + {"config": (16, 32, 32, 4, 2), "cond": True}, + {"config": (16, 32, 32, 5, 2), "cond": True}, + {"config": (64, 64, 128, 3, 4), "cond": True}, + {"config": (128, 64, 32, 2, 2), "cond": True}, + {"config": (128, 64, 64, 3, 8), "cond": True}, + {"config": (128, 64, 128, 4, 8), "cond": True}, + {"config": (128, 128, 32, 4, 4), "cond": True}, + {"config": (128, 128, 64, 3, 8), "cond": True}, + {"config": (128, 128, 64, 5, 4), "cond": True}, +] + +int8_mm_kernel_configs = [ + {"config": (64, 64, 32, 2, 4), "cond": True}, + {"config": (64, 128, 32, 3, 4), "cond": True}, + {"config": (128, 64, 32, 3, 4), "cond": True}, + {"config": (64, 128, 32, 4, 8), "cond": True}, + {"config": (128, 64, 32, 4, 8), "cond": True}, + {"config": (64, 32, 32, 5, 8), "cond": True}, + {"config": (32, 64, 32, 5, 8), "cond": True}, + {"config": (128, 128, 32, 2, 8), "cond": True}, + {"config": (64, 64, 64, 3, 8), "cond": True}, + # {"config": (32, 32, 128, 2, 4), "cond": True}, + # {"config": (64, 64, 16, 2, 4), "cond": True}, + # {"config": (32, 32, 16, 1, 2), "cond": True}, + {"config": (128, 256, 128, 3, 8), "cond": torch.version.hip is None}, + {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None}, +] + +# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192). 
+mixed_mm_kernel_configs_small_m = [ + {"config": (16, 128, 256, 3, 4), "cond": True}, + {"config": (16, 128, 256, 5, 8), "cond": True}, +] + +mixed_mm_kernel_configs = ( + mm_kernel_configs + mixed_mm_kernel_configs_small_m + if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE" + else mm_kernel_configs +) + +scaled_mm_kernel_configs = [ + {"config": (128, 256, 32, 3, 8), "cond": True}, + {"config": (256, 128, 32, 3, 8), "cond": True}, + {"config": (256, 64, 32, 4, 4), "cond": True}, + {"config": (64, 256, 32, 4, 4), "cond": True}, + {"config": (128, 128, 32, 4, 4), "cond": True}, + {"config": (128, 64, 32, 4, 4), "cond": True}, + {"config": (64, 128, 32, 4, 4), "cond": True}, + {"config": (128, 32, 32, 4, 4), "cond": True}, + {"config": (64, 32, 32, 5, 2), "cond": True}, + {"config": (256, 128, 128, 3, 8), "cond": True}, + {"config": (256, 64, 128, 4, 4), "cond": True}, + {"config": (64, 256, 128, 4, 4), "cond": True}, + {"config": (128, 128, 128, 4, 4), "cond": True}, + {"config": (128, 64, 64, 4, 4), "cond": True}, + {"config": (64, 128, 64, 4, 4), "cond": True}, + {"config": (128, 32, 64, 4, 4), "cond": True}, + {"config": (64, 32, 64, 5, 2), "cond": True}, + {"config": (16, 32, 32, 2, 2), "cond": True}, + {"config": (16, 64, 32, 2, 2), "cond": True}, + {"config": (16, 128, 32, 2, 4), "cond": True}, + {"config": (16, 256, 32, 2, 4), "cond": True}, + {"config": (16, 32, 64, 2, 2), "cond": True}, + {"config": (16, 64, 64, 2, 2), "cond": True}, + {"config": (16, 128, 64, 2, 4), "cond": True}, + {"config": (16, 256, 64, 2, 4), "cond": True}, + {"config": (32, 32, 32, 2, 2), "cond": True}, + {"config": (32, 64, 32, 2, 2), "cond": True}, + {"config": (32, 128, 32, 2, 4), "cond": True}, + {"config": (32, 256, 32, 2, 4), "cond": True}, + {"config": (32, 32, 64, 2, 2), "cond": True}, + {"config": (32, 64, 64, 2, 2), "cond": True}, + {"config": (32, 128, 64, 2, 4), "cond": True}, + {"config": (32, 256, 64, 2, 4), "cond": True}, + {"config": (16, 32, 32, 3, 2), "cond": True}, + {"config": (16, 64, 32, 3, 2), "cond": True}, + {"config": (16, 128, 32, 3, 4), "cond": True}, + {"config": (16, 256, 32, 3, 4), "cond": True}, + {"config": (16, 32, 64, 3, 2), "cond": True}, + {"config": (16, 64, 64, 3, 2), "cond": True}, + {"config": (16, 128, 64, 3, 4), "cond": True}, + {"config": (16, 256, 64, 3, 4), "cond": True}, + {"config": (32, 32, 32, 3, 2), "cond": True}, + {"config": (32, 64, 32, 3, 2), "cond": True}, + {"config": (32, 128, 32, 3, 4), "cond": True}, + {"config": (32, 256, 32, 3, 4), "cond": True}, + {"config": (32, 32, 64, 3, 2), "cond": True}, + {"config": (32, 64, 64, 3, 2), "cond": True}, + {"config": (32, 128, 64, 3, 4), "cond": True}, + {"config": (32, 256, 64, 3, 4), "cond": True}, + {"config": (16, 32, 32, 4, 2), "cond": True}, + {"config": (16, 64, 32, 4, 2), "cond": True}, + {"config": (16, 128, 32, 4, 4), "cond": True}, + {"config": (16, 256, 32, 4, 4), "cond": True}, + {"config": (16, 32, 64, 4, 2), "cond": True}, + {"config": (16, 64, 64, 4, 2), "cond": True}, + {"config": (16, 128, 64, 4, 4), "cond": True}, + {"config": (16, 256, 64, 4, 4), "cond": True}, + {"config": (32, 32, 32, 4, 2), "cond": True}, + {"config": (32, 64, 32, 4, 2), "cond": True}, + {"config": (32, 128, 32, 4, 4), "cond": True}, + {"config": (32, 256, 32, 4, 4), "cond": True}, + {"config": (32, 32, 64, 4, 2), "cond": True}, + {"config": (32, 64, 64, 4, 2), "cond": True}, + {"config": (32, 128, 64, 4, 4), "cond": True}, + {"config": (32, 256, 64, 4, 4), "cond": True}, + {"config": (16, 32, 32, 5, 2), 
"cond": True}, + {"config": (16, 64, 32, 5, 2), "cond": True}, + {"config": (16, 128, 32, 5, 4), "cond": True}, + {"config": (16, 256, 32, 5, 4), "cond": True}, + {"config": (16, 32, 64, 5, 2), "cond": True}, + {"config": (16, 64, 64, 5, 2), "cond": True}, + {"config": (16, 128, 64, 5, 4), "cond": True}, + {"config": (16, 256, 64, 5, 4), "cond": True}, + {"config": (32, 32, 32, 5, 2), "cond": True}, + {"config": (32, 64, 32, 5, 2), "cond": True}, + {"config": (32, 128, 32, 5, 4), "cond": True}, + {"config": (32, 256, 32, 5, 4), "cond": True}, + {"config": (32, 32, 64, 5, 2), "cond": True}, + {"config": (32, 64, 64, 5, 2), "cond": True}, + {"config": (32, 128, 64, 5, 4), "cond": True}, + {"config": (32, 256, 64, 5, 4), "cond": True}, + {"config": (16, 32, 32, 6, 2), "cond": True}, + {"config": (16, 64, 32, 6, 2), "cond": True}, + {"config": (16, 128, 32, 6, 4), "cond": True}, + {"config": (16, 256, 32, 6, 4), "cond": True}, + {"config": (16, 32, 64, 6, 2), "cond": True}, + {"config": (16, 64, 64, 6, 2), "cond": True}, + {"config": (16, 128, 64, 6, 4), "cond": True}, + {"config": (16, 256, 64, 6, 4), "cond": True}, + {"config": (32, 32, 32, 6, 2), "cond": True}, + {"config": (32, 64, 32, 6, 2), "cond": True}, + {"config": (32, 128, 32, 6, 4), "cond": True}, + {"config": (32, 256, 32, 6, 4), "cond": True}, + {"config": (32, 32, 64, 6, 2), "cond": True}, + {"config": (32, 64, 64, 6, 2), "cond": True}, + {"config": (32, 128, 64, 6, 4), "cond": True}, + {"config": (32, 256, 64, 6, 4), "cond": True}, +] + + +# Create filtered list of configs based on cond evaluation +mm_platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in mm_kernel_configs + if config["cond"] +) +extra_mm_platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in extra_mm_kernel_configs + if config["cond"] +) +int8_platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in int8_mm_kernel_configs + if config["cond"] +) +mixed_mm_platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in mixed_mm_kernel_configs + if config["cond"] +) +scaled_mm_platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in scaled_mm_kernel_configs + if config["cond"] +) + +# On ROCm convert num_stages to 0 to enable software pipelining +if torch.version.hip: + mm_platform_configs = tuple( + (config[0], config[1], config[2], 0, config[4]) + for config in mm_platform_configs + ) + extra_mm_platform_configs = tuple( + (config[0], config[1], config[2], 0, config[4]) + for config in extra_mm_platform_configs + ) + int8_platform_configs = tuple( + (config[0], config[1], config[2], 0, config[4]) + for config in mm_platform_configs + ) + mixed_mm_platform_configs = tuple( + (config[0], config[1], config[2], 0, config[4]) + for config in mixed_mm_platform_configs + ) + scaled_mm_platform_configs = tuple( + (config[0], config[1], config[2], 0, config[4]) + for config in scaled_mm_platform_configs + ) + +mm_configs = functools.partial( + filtered_configs, + configs=mm_platform_configs, +) + +extra_mm_configs = functools.partial( + filtered_configs, + configs=extra_mm_platform_configs, +) + +int8_mm_configs = functools.partial( + filtered_configs, + configs=int8_platform_configs, +) + +mixed_mm_configs = functools.partial( + filtered_configs, + configs=mixed_mm_platform_configs, +) + +scaled_mm_configs = functools.partial( + filtered_configs, + 
configs=scaled_mm_platform_configs, +) + + +def mm_grid(m, n, meta): + """ + The CUDA grid size for matmul triton templates. + """ + return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1) + + +def acc_type(dtype): + if dtype in (torch.float16, torch.bfloat16): + return "tl.float32" + return f"tl.{dtype}".replace("torch.", "") + + +def mm_options(config, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None): + """ + Common options to matmul triton templates. + """ + even_k_symbolic = ( + # it isn't worth guarding on this + sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) + == config.kwargs["BLOCK_K"] + ) + allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and ( + not inductor_config.force_same_precision + or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0) + ) + return dict( + GROUP_M=8, + EVEN_K=even_k_symbolic, + ALLOW_TF32=allow_tf32, + ACC_TYPE=acc_type(layout.dtype), + B_PROLOGUE_CAST_TYPE=b_prologue_cast_type, + num_stages=config.num_stages, + num_warps=config.num_warps, + **config.kwargs, + ) + + +def mm_args( + mat1, + mat2, + *others, + layout=None, + out_dtype=None, + use_4x2_dim=False, + mat2_transposed=False, +): + """ + Common arg processing for mm,bmm,addmm,etc + """ + mat1, mat2 = realize_inputs(mat1, mat2) + *b1, m, k1 = mat1.get_size() + if mat2_transposed: + *b2, n, k2 = mat2.get_size() + else: + *b2, k2, n = mat2.get_size() + b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)] + if use_4x2_dim: + k2 = k2 * 2 + k = V.graph.sizevars.guard_equals(k1, k2) + if layout is None: + from torch._inductor.ir import FixedLayout + + if out_dtype is None: + out_dtype = mat1.get_dtype() + + layout = FixedLayout( + mat1.get_device(), + out_dtype, + [*b, m, n], + ) + else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." 
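A standalone arithmetic sketch, with made-up sizes, of the 1-D launch grid computed by mm_grid above and of the EVEN_K condition used by mm_options:

def cdiv(a, b):
    return (a + b - 1) // b

M, N = 1000, 768
meta = {"BLOCK_M": 64, "BLOCK_N": 128}
# one program per (BLOCK_M x BLOCK_N) output tile
grid = (cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"]), 1, 1)
assert grid == (96, 1, 1)   # 16 tiles along M times 6 tiles along N

K, BLOCK_K = 4096, 32
# for concrete sizes, gcd(K, BLOCK_K) == BLOCK_K reduces to BLOCK_K dividing K
EVEN_K = (K % BLOCK_K == 0)
assert EVEN_K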
+ from ..lowering import expand + + others = [realize_inputs(expand(x, layout.size)) for x in others] + + return [m, n, k, layout, mat1, mat2, *others] + + +def addmm_epilogue(dtype, alpha, beta): + def epilogue(acc, bias): + if alpha != 1: + acc = V.ops.mul(acc, V.ops.constant(alpha, dtype)) + if beta != 1: + bias = V.ops.mul(bias, V.ops.constant(beta, dtype)) + return V.ops.add(acc, bias) + + return epilogue diff --git a/.venv/Lib/site-packages/torch/_inductor/kernel/mm_plus_mm.py b/.venv/Lib/site-packages/torch/_inductor/kernel/mm_plus_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..3c203610d15923ed764bffb69899ebc6e3c680c4 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/kernel/mm_plus_mm.py @@ -0,0 +1,248 @@ +# mypy: allow-untyped-defs +import functools + +import torch + +from ..lowering import lowerings +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + TritonTemplate, +) +from ..utils import use_aten_gemm_kernels, use_triton_template +from ..virtualized import V +from .mm_common import mm_args, mm_grid, mm_options + + +aten = torch.ops.aten + +aten_mm_plus_mm = ExternKernelChoice( + torch.ops.inductor._mm_plus_mm, "torch::inductor::_mm_plus_mm" +) + +mm_plus_mm_template = TritonTemplate( + name="mm_plus_mm", + grid=mm_grid, + debug=False, + source=r""" +{{def_kernel("A", "B", "C", "D")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K1 = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + # K2 = {{size("C", 1)}} + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + stride_cm = {{stride("C", 0)}} + stride_ck = {{stride("C", 1)}} + stride_dk = {{stride("D", 0)}} + stride_dn = {{stride("D", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + if (((stride_am == 1 and stride_ak == M) or (stride_am == K1 and stride_ak == 1)) + and ((stride_cm == 1 and stride_ck == M) or (stride_cm == K1 and stride_ck == 1))): + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + ram = rm % M + + if (((stride_bk == 1 and stride_bn == K1) or (stride_bk == N and stride_bn == 1)) + and ((stride_dk == 1 and stride_dn == K1) or (stride_dk == N and stride_dn == 1))): + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + rbn = rn % N + + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck) + D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k1 in range(K1, 0, -BLOCK_K): + # First matmul with A @ B + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k1, other=0.) + b = tl.load(B, mask=rk[:, None] < k1, other=0.) 
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + for k2 in range(K1, 0, -BLOCK_K): + + # Second matmul with C @ D + if EVEN_K: + c = tl.load(C) + d = tl.load(D) + else: + c = tl.load(C, mask=rk[None, :] < k2, other=0.) + d = tl.load(D, mask=rk[:, None] < k2, other=0.) + acc += tl.dot(c, d, allow_tf32=ALLOW_TF32) + C += BLOCK_K * stride_ck + D += BLOCK_K * stride_dk + + + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""", +) + + +@functools.lru_cache(None) +def mm_configs(): + import triton + + # List of dictionaries to store the kernel configs. Configs that evaluate to true + # will be utilised on the target platform + mm_triton_configs = [ + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, + "num_stages": 2, + "num_warps": 4, + "cond": True, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, + "num_stages": 3, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, + "num_stages": 4, + "num_warps": 16, + "cond": True, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32}, + "num_stages": 4, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, + "num_stages": 4, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, + "num_stages": 1, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, + "num_stages": 1, + "num_warps": 8, + "cond": True, + }, + { + "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, + "num_stages": 1, + "num_warps": 8, + "cond": torch.version.hip is None, + }, + { + "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, + "num_stages": 2, + "num_warps": 4, + "cond": True, + }, + { + "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16}, + "num_stages": 1, + "num_warps": 2, + "cond": True, + }, + ] + + # Filter out configs in which cond evaluates to true + # On ROCm convert num_stages to 1 as pipelining provides no benefit + if torch.version.hip: + filtered_configs = [ + triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"]) + for c in mm_triton_configs + if c["cond"] + ] + else: + filtered_configs = [ + triton.Config( + c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"] + ) + for c in mm_triton_configs + if c["cond"] + ] + + return filtered_configs + + +def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): + """ + Computes mm(mat1, mat2) + mm(mat3, mat4) + """ + m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout) + # Optimization is optional, because we can always just not do the fusion + if ( + m1 * n1 == 0 + or m2 * n2 == 0 + or not V.graph.sizevars.statically_known_list_equals( + mat1.get_size(), mat3.get_size() + ) + or not V.graph.sizevars.statically_known_list_equals( + mat2.get_size(), mat4.get_size() + ) + ): + # TODO(jansel): support different K values when this is fixed: + # https://github.com/openai/triton/issues/967 + return lowerings[aten.add]( + lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4) + ) + + assert layout1 == layout2 + # options to tune from + choices = ( + [aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)] + if use_aten_gemm_kernels() + else [] + ) + if use_triton_template(layout1): + for config 
in mm_configs(): + # see https://github.com/openai/triton/issues/1298 + # BLOCK_K = K causes llvm error + if V.graph.sizevars.statically_known_lt(config.kwargs["BLOCK_K"], k1): + mm_plus_mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2, mat3, mat4), + layout=layout1, + **mm_options(config, m1, n1, k1, layout1), + ) + + return autotune_select_algorithm( + "mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1 + ) diff --git a/.venv/Lib/site-packages/torch/_inductor/kernel/mm_scaled.py b/.venv/Lib/site-packages/torch/_inductor/kernel/mm_scaled.py new file mode 100644 index 0000000000000000000000000000000000000000..dcb2aa4c77d751e2036a51d44f456bb1e1003806 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/kernel/mm_scaled.py @@ -0,0 +1,311 @@ +import logging +from typing import Any, Dict, List, Optional, Tuple + +import sympy + +import torch + +from .. import config as inductor_config +from ..ir import ChoiceCaller, Layout, StorageBox, TensorBox +from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + NoValidChoicesError, + realize_inputs, + TritonTemplate, +) +from ..utils import use_aten_gemm_kernels, use_triton_template +from .mm import _is_static_problem # TODO(yangsiyu) move to mm_common +from .mm_common import mm_args, mm_grid, scaled_mm_configs + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +scaled_mm_template = TritonTemplate( + name="scaled_mm", + grid=mm_grid, + source=r""" +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) 
+ if B_PROLOGUE_CAST_TYPE is not None: + b = b.to(B_PROLOGUE_CAST_TYPE) + if USE_FAST_ACCUM: + acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE) + else: + acc += tl.dot(a, b, out_dtype=ACC_TYPE) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + if SCALING_ROWWISE: + inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M) + inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N) + inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :] + acc *= inv_scale_row + else: + # for tensor-wise scaling, the scales are scalars + inv_a_scale = tl.load(A_inverse_scale) + inv_b_scale = tl.load(B_inverse_scale) + inv_scale = inv_a_scale * inv_b_scale + acc *= inv_scale + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""", +) + + +# Inductor does not allow optional tensor input arguments currently (pass None as an +# input node to template choices), but since for _scaled_mm there is only one such arg +# (bias), work around by having a second template when bias is provided. +scaled_mm_bias_template = TritonTemplate( + name="scaled_mm_bias", + grid=mm_grid, + source=r""" +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale", "bias_ptr")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) 
+ if B_PROLOGUE_CAST_TYPE is not None: + b = b.to(B_PROLOGUE_CAST_TYPE) + if USE_FAST_ACCUM: + acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE) + else: + acc += tl.dot(a, b, out_dtype=ACC_TYPE) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + if SCALING_ROWWISE: + inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M) + inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N) + inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :] + acc *= inv_scale_row + else: + # for tensor-wise scaling, the scales are scalars + inv_a_scale = tl.load(A_inverse_scale) + inv_b_scale = tl.load(B_inverse_scale) + inv_scale = inv_a_scale * inv_b_scale + acc *= inv_scale + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # bias + bias = tl.load(bias_ptr + rn, mask=rn < N) + acc += bias + + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""", +) + + +aten__fp8_mm = ExternKernelChoice(torch._scaled_mm, "at::_scaled_mm") + + +def are_compatible_scales(size_a: List[int], size_b: List[int]) -> bool: + # Same sized scales are compatable + if len(size_a) == len(size_b): + return True + + # Both need to be scalars or len(1) tensors + if len(size_a) <= 1 and len(size_b) <= 1: + return True + + return False + + +def scaled_mm_options( # type: ignore[no-untyped-def] + config, # triton.Config + sym_m: sympy.core.numbers.Integer, + sym_n: sympy.core.numbers.Integer, + sym_k: sympy.core.numbers.Integer, + layout: Layout, + scale_a: StorageBox, + scale_b: StorageBox, + use_fast_accum: bool, + b_prologue_cast_type: Optional[str] = None, +) -> Dict[str, Any]: + even_k_symbolic = ( + sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] + ) + + size_a, size_b = scale_a.get_size(), scale_b.get_size() + assert are_compatible_scales(size_a, size_b), ( + "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " + f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." + ) + return dict( + GROUP_M=8, + EVEN_K=even_k_symbolic, + ACC_TYPE="tl.float32", + B_PROLOGUE_CAST_TYPE=b_prologue_cast_type, + USE_FAST_ACCUM=use_fast_accum, + num_stages=config.num_stages, + num_warps=config.num_warps, + # tensor-wise scaling if scalar scales + SCALING_ROWWISE=len(scale_a.get_size()) == 2, + **config.kwargs, + ) + + +add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides) + + +@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc] +def tuned_scaled_mm( + mat_a: TensorBox, + mat_b: TensorBox, + scale_a: TensorBox, + scale_b: TensorBox, + bias: Optional[TensorBox] = None, + scale_result: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, + layout: Optional[Layout] = None, +) -> TensorBox: + m, n, k, layout, mat_a, mat_b = mm_args( + mat_a, mat_b, layout=layout, out_dtype=out_dtype + ) + scale_a, scale_b = realize_inputs(scale_a, scale_b) + + input_nodes: Tuple[Any, ...] 
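+ # The tuning flow below gathers an ATen extern choice (when use_aten_gemm_kernels()) and Triton template configs, autotunes over them, and can fall back to the ATen kernel when no valid choice remains.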
+ # workaround for Inductor not supporting optional tensor input arguments + if bias is None: + input_nodes = (mat_a, mat_b, scale_a, scale_b) + triton_template = scaled_mm_template + else: + bias = realize_inputs(bias) + input_nodes = (mat_a, mat_b, scale_a, scale_b, bias) + triton_template = scaled_mm_bias_template + + aten_choice = aten__fp8_mm.bind( + input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum + ) + + choices: List[ChoiceCaller] = [] + if use_aten_gemm_kernels(): + choices.append(aten_choice) + + static_shape, is_nonzero = _is_static_problem([mat_a, mat_b], layout) + if is_nonzero and use_triton_template(layout, enable_float8=True): + for config in scaled_mm_configs(m, n, k): + if k == 16 and config.kwargs["BLOCK_M"] >= 64: + continue # Triton crashes in this case + kwargs = scaled_mm_options( + config, m, n, k, layout, scale_a, scale_b, use_fast_accum + ) + # possibly appends a TritonTemplateCaller to choices + triton_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=layout, + **kwargs, + ) + + if ( + len(choices) == 0 + and not use_aten_gemm_kernels() + and inductor_config.autotune_fallback_to_aten + ): + log.warning("No choices for scaled_mm, using ATen backend as fallback") + return aten_choice.output_node() + + try: + return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) + except NoValidChoicesError: + if not inductor_config.autotune_fallback_to_aten: + raise + log.warning( + "All choices for scaled_mm were invalid, using ATen backend as fallback" + ) + return aten_choice.output_node() diff --git a/.venv/Lib/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py b/.venv/Lib/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..6acc01d950b5671e35424374169de9ec3329de3d --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py @@ -0,0 +1,87 @@ +# mypy: allow-untyped-defs +import logging +from typing import List, TYPE_CHECKING + +from ..select_algorithm import autotune_select_algorithm, TritonTemplate +from .mm_common import mm_args, mm_configs, mm_grid, mm_options + + +if TYPE_CHECKING: + from ..ir import ChoiceCaller + +log = logging.getLogger(__name__) + +uint4x2_mixed_mm_template = TritonTemplate( + name="uint4x2_mixed_mm", + grid=mm_grid, + source=r""" +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None]//2 * stride_bk + rbn[None, :] * stride_bn) + b_shifts = 4*(rk%2) + b_subs = 8*(1-(rk%2)) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b 
= tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) + b = ((b >> b_shifts[:, None]) & 0xF) - 8 + b = b.to(B_PROLOGUE_CAST_TYPE) + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K//2 * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""", +) + + +def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype): + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True) + choices: List[ChoiceCaller] = [] + b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "") + for config in mm_configs(m, n, k): + uint4x2_mixed_mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout, b_prologue_cast_type), + ) + return autotune_select_algorithm("uint4x2_mixed_mm", choices, [mat1, mat2], layout) diff --git a/.venv/Lib/site-packages/torch/_inductor/package/__init__.py b/.venv/Lib/site-packages/torch/_inductor/package/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..237540fb19ebc56e1f81934fc13d55da51771318 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/package/__init__.py @@ -0,0 +1 @@ +from .package import load_package, package_aoti diff --git a/.venv/Lib/site-packages/torch/_inductor/package/build_package.py b/.venv/Lib/site-packages/torch/_inductor/package/build_package.py new file mode 100644 index 0000000000000000000000000000000000000000..b4a64dc1369f7b6e33ebc718363d8d031e7ebbe6 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/package/build_package.py @@ -0,0 +1,15 @@ +build_package_contents = """ +import os +from pathlib import Path + +from torch._inductor.package.package import compile_so + +curr_dir = Path(__file__).parent +aoti_files = [ + os.path.join(root, file) + for root, dirs, files in os.walk(curr_dir) + for file in files +] + +output_so = compile_so(curr_dir, aoti_files, curr_dir) +""" diff --git a/.venv/Lib/site-packages/torch/_inductor/package/package.py b/.venv/Lib/site-packages/torch/_inductor/package/package.py new file mode 100644 index 0000000000000000000000000000000000000000..d014b4cffe5602236e229840ccdeed489dd06f37 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/package/package.py @@ -0,0 +1,237 @@ +import glob +import json +import os +import shlex +import subprocess +import tempfile +import zipfile +from pathlib import Path +from typing import Callable, List, Optional, Union + +import torch +import torch._inductor +import torch.utils._pytree as pytree +from torch._inductor import config, exc +from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder +from torch.export._tree_utils import reorder_kwargs + +from .build_package import build_package_contents +from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION + + +class PT2ArchiveWriter: + def __init__(self, archive_path: str) -> None: + self.archive_path: str = archive_path + self.archive_file: Optional[zipfile.ZipFile] = None + + def __enter__(self) -> "PT2ArchiveWriter": + assert self.archive_file is None + self.archive_file = zipfile.ZipFile( + self.archive_path, "w", compression=zipfile.ZIP_STORED + ) + self.writestr("version", str(ARCHIVE_VERSION)) + 
self.writestr("archive_format", "pt2") + return self + + def __exit__(self, *args) -> None: # type: ignore[no-untyped-def] + assert self.archive_file is not None + self.archive_file.close() + self.archive_file = None + return None + + def writestr(self, name: str, data: Union[bytes, str]) -> None: + assert self.archive_file is not None + self.archive_file.writestr(name, data) + + def write_file(self, name: str, file_path: str) -> None: + """ + Copy a file into the archive. + name: The destination file inside the archive. + file_path: The source file on disk. + """ + assert Path(file_path).is_file(), f"{file_path} is not a valid file path" + assert self.archive_file is not None + self.archive_file.write(file_path, arcname=name) + + +class PT2ArchiveReader: + def __init__(self, archive_path: str) -> None: + self.archive_path: str = archive_path + self.archive_file: Optional[zipfile.ZipFile] = None + + def __enter__(self) -> "PT2ArchiveReader": + self.archive_file = zipfile.ZipFile( + self.archive_path, "r", compression=zipfile.ZIP_STORED + ) + return self + + def __exit__(self, *args) -> None: # type: ignore[no-untyped-def] + if self.archive_file is not None: + self.archive_file.close() + return None + + def read(self, name: str) -> bytes: + assert self.archive_file is not None + return self.archive_file.read(name) + + def extract_to_path(self, member: str, path: str) -> str: + assert self.archive_file is not None + return self.archive_file.extract(member, path) + + def extractall(self, path: str) -> None: + assert self.archive_file is not None + self.archive_file.extractall(path) + + def get_file_names(self) -> List[str]: + assert self.archive_file is not None + return self.archive_file.namelist() + + +def _run_command_and_check(cmd: str) -> None: + cmd = shlex.split(cmd) + try: + subprocess.run(cmd, check=True) + except subprocess.CalledProcessError as e: + raise exc.CppCompileError(cmd, e.output) from e + + +def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str: + def get_aoti_file_with_suffix(suffix: str) -> str: + for file in aoti_files: + if file.endswith(suffix): + return file + raise RuntimeError(f"Unable to find file with suffix {suffix}") + + # Compile all the files into a .so + cpp_file = os.path.join(aoti_dir, get_aoti_file_with_suffix(".cpp")) + consts_o = os.path.join(aoti_dir, get_aoti_file_with_suffix(".o")) + + file_name = os.path.splitext(cpp_file)[0] + + # Parse compile flags and build the .o file + with open(file_name + "_compile_flags.json") as f: + compile_flags = json.load(f) + + compile_options = BuildOptionsBase(**compile_flags) + object_builder = CppBuilder( + name=file_name, + sources=cpp_file, + BuildOption=compile_options, + ) + compile_cmd = object_builder.get_command_line() + output_o = object_builder.get_target_file_path() + + _run_command_and_check(compile_cmd) + + # Parse linker flags and build the .so file + with open(file_name + "_linker_flags.json") as f: + linker_flags = json.load(f) + + linker_options = BuildOptionsBase(**linker_flags) + so_builder = CppBuilder( + name=os.path.split(so_path)[-1], + sources=[output_o, consts_o], + BuildOption=linker_options, + output_dir=so_path, + ) + link_cmd = so_builder.get_command_line() + output_so = so_builder.get_target_file_path() + + _run_command_and_check(link_cmd) + + # mmapped weights + serialized_weights_filename = file_name + "_serialized_weights.bin" + if serialized_weights_filename in aoti_files: + with open(serialized_weights_filename, "rb") as f_weights: + serialized_weights = 
f_weights.read() + + with open(output_so, "a+b") as f_so: + so_size = f_so.tell() + # Page align the weights + f_so.write(b" " * (16384 - so_size % 16384)) + f_so.write(serialized_weights) + + return output_so + + +def package_aoti(aoti_output_dir: str) -> str: + """ + Saves the AOTInductor generated files to the PT2Archive format. + """ + + # Add a makefile and python script + build_package_filename = "build_package.py" + with open(os.path.join(aoti_output_dir, build_package_filename), "w") as f: + f.write(build_package_contents) + + with open(os.path.join(aoti_output_dir, "Makefile"), "w") as f: + f.write(f"all:\n\tpython3 {build_package_filename}\n") + + if config.aot_inductor.output_path.endswith(".so"): + raise RuntimeError( + "Unable to save package as a .so. It should be a .pt2 format or a directory." + ) + elif config.aot_inductor.output_path.endswith(".pt2"): + # Save using the PT2 packaging format + # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a) + archive_path = config.aot_inductor.output_path + + with PT2ArchiveWriter(archive_path) as archive_writer: + package_files = glob.glob(f"{aoti_output_dir}/*") + + for path in package_files: + filename = os.path.basename(path) + archive_writer.write_file(f"{AOTINDUCTOR_DIR}{filename}", path) + + return archive_path + + else: + # Directly put the files into the directory, without any archiving + return aoti_output_dir + + +def load_package(path: str, device: str) -> Callable: # type: ignore[type-arg] + if path.endswith(".so"): + raise RuntimeError( + "Unable to load .so. It should be a .pt2 format or a directory." + ) + + elif path.endswith(".pt2"): + so_path = os.path.splitext(path)[0] + with PT2ArchiveReader(path) as archive_reader: + file_names = archive_reader.get_file_names() + + with tempfile.TemporaryDirectory() as tmp_dir: + archive_reader.extractall(tmp_dir) + file_names = archive_reader.get_file_names() + aoti_files = [ + file for file in file_names if file.startswith(AOTINDUCTOR_DIR) + ] + + so_path = compile_so(tmp_dir, aoti_files, so_path) + + else: + assert os.path.isdir(path), "Must specify a directory or a .pt2 file" + aoti_files = [ + os.path.join(root, file) + for root, dirs, files in os.walk(path) + for file in files + ] + so_path = compile_so(path, aoti_files, path) + + if device == "cpu": + runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg] + elif device == "cuda" or device.startswith("cuda:"): + runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg] + else: + raise RuntimeError("Unsupported device " + device) + + def optimized(*args, **kwargs): # type: ignore[no-untyped-def] + call_spec = runner.get_call_spec() # type: ignore[attr-defined] + in_spec = pytree.treespec_loads(call_spec[0]) + out_spec = pytree.treespec_loads(call_spec[1]) + flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] + flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined] + return pytree.tree_unflatten(flat_outputs, out_spec) + + return optimized diff --git a/.venv/Lib/site-packages/torch/_inductor/package/pt2_archive_constants.py b/.venv/Lib/site-packages/torch/_inductor/package/pt2_archive_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..81ad5dd773859dd45974a6f256fbe9b83434878e --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/package/pt2_archive_constants.py @@ -0,0 +1,16 @@ +ARCHIVE_ROOT_NAME = "package" 
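+# The constants below define the on-disk layout of a .pt2 archive; package_aoti() above writes AOTInductor outputs under AOTINDUCTOR_DIR ("data/aotinductor/").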
+ARCHIVE_FORMAT_PATH = "archive_format" +MODELS_DIR = "models/" +MODELS_FILENAME_FORMAT = "models/{}.json" +AOTINDUCTOR_DIR = "data/aotinductor/" +WEIGHTS_DIR = "data/weights/" +WEIGHT_FILENAME_PREFIX = "weight_" +CONSTANTS_DIR = "data/constants/" +TENSOR_CONSTANT_FILENAME_PREFIX = "tensor_" +CUSTOM_OBJ_FILENAME_PREFIX = "custom_obj_" +SAMPLE_INPUTS_DIR = "data/sample_inputs/" +SAMPLE_INPUTS_FILENAME_FORMAT = "data/sample_inputs/{}.pt" +EXTRA_DIR = "extra/" +MODULE_INFO_PATH = "extra/module_info.json" + +ARCHIVE_VERSION = 0 diff --git a/.venv/Lib/site-packages/torch/_inductor/runtime/__init__.py b/.venv/Lib/site-packages/torch/_inductor/runtime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/Lib/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-39.pyc b/.venv/Lib/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79d6b14ba8c9c864a255fbb51081057a2e3558a9 Binary files /dev/null and b/.venv/Lib/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-39.pyc differ diff --git a/.venv/Lib/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-39.pyc b/.venv/Lib/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..464c0c27fb50c213f5a5705dc8e96c11bcc0e8da Binary files /dev/null and b/.venv/Lib/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-39.pyc differ diff --git a/.venv/Lib/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-39.pyc b/.venv/Lib/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e1a75b247f49b8733b6250534143ce5f416cd18 Binary files /dev/null and b/.venv/Lib/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-39.pyc differ diff --git a/.venv/Lib/site-packages/torch/_inductor/runtime/autotune_cache.py b/.venv/Lib/site-packages/torch/_inductor/runtime/autotune_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..6bf8af03c4ccf33b98fa0e27b408d84b61f2aa1d --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/runtime/autotune_cache.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +import dataclasses +import hashlib +import logging +import os +import os.path +from typing import Dict, List, Optional, Tuple +from typing_extensions import override + +import torch +from torch.utils._triton import has_triton_package + +from ..remote_cache import ( + JsonDataTy, + RemoteCache, + RemoteCacheBackend, + RemoteCacheJsonSerde, +) + + +if has_triton_package(): + from triton import Config + +log = logging.getLogger(__name__) + + +_InductorMetaTy = Dict[str, object] + + +@dataclasses.dataclass +class AutotuneCache: + configs_hash: str + filename: str + local_cache: Optional[Tuple[RemoteCache[JsonDataTy], str]] = None + remote_cache: Optional[Tuple[RemoteCache[JsonDataTy], str]] = None + + # Create a AutotuneCache. Returns None if none of the caches can be used. 
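+ # Typical flow: create() when the kernel is compiled, read_best() to reuse a previously saved config, and save() after autotuning picks a winner.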
+ @staticmethod + def create( + inductor_meta: _InductorMetaTy, filename: str, configs_hash: str + ) -> Optional[AutotuneCache]: + cache = AutotuneCache(configs_hash, filename) + cache._setup_local_cache(inductor_meta, filename) + cache._setup_remote_autotune_cache(inductor_meta, filename) + if cache.local_cache or cache.remote_cache: + return cache + else: + return None + + # Read the best config options from the most local cache and return it. + def _read(self, inductor_meta: _InductorMetaTy) -> Optional[Dict[str, JsonDataTy]]: + if local_cache := self.local_cache: + cache, key = local_cache + if best_config := cache.get(key): + if isinstance(best_config, dict): + return best_config + + if remote_cache := self.remote_cache: + cache, key = remote_cache + if best_config := cache.get(key): + if isinstance(best_config, dict): + return best_config + + return None + + # Read the best config options from the most local cache and figure out + # which `configs` represents that option. + def read_best( + self, inductor_meta: _InductorMetaTy, configs: List[Config] + ) -> Optional[Config]: + if best := self._read(inductor_meta): + return _load_cached_autotuning( + best, self.configs_hash, configs, inductor_meta + ) + return None + + # Set up local filesystem caching information + def _setup_local_cache(self, inductor_meta: _InductorMetaTy, filename: str) -> None: + if not inductor_meta.get("autotune_local_cache", True): + return + + cache_filename = os.path.splitext(filename)[0] + ".best_config" + local_cache = RemoteCache(_LocalAutotuneCacheBackend(), RemoteCacheJsonSerde()) + self.local_cache = (local_cache, cache_filename) + + # Set up remote caching information + def _setup_remote_autotune_cache( + self, inductor_meta: _InductorMetaTy, filename: str + ) -> None: + if not _should_use_remote_autotune_cache(inductor_meta): + return + + remote_cache = _create_cache( + inductor_meta, + self.configs_hash, + "FbRemoteAutotuneCache", + "RemoteAutotuneCache", + "autotune-best-config-v2", + ) + if not remote_cache: + return + + # we already sha256 hash the source contents + remote_cache_key = os.path.basename(filename) + self.remote_cache = (remote_cache, remote_cache_key) + + # Save the config in the caches + def save( + self, config: Config, time_taken_ns: int, found_by_coordesc: bool = False + ) -> None: + data = { + **config.kwargs, + "num_warps": config.num_warps, + "num_stages": config.num_stages, + "configs_hash": self.configs_hash, + "found_by_coordesc": found_by_coordesc, + "time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS + } + + if local_cache := self.local_cache: + cache, key = local_cache + cache.put(key, data) + + if log.isEnabledFor(logging.DEBUG): + type_str = "coordesc" if found_by_coordesc else "heuristic" + log.debug("Save %s tuning result to %s", type_str, key) + + if remote_cache := self.remote_cache: + cache, key = remote_cache + cache.put(key, data) + + +def _should_use_remote_autotune_cache(inductor_meta: Dict[str, object]) -> bool: + if (config := inductor_meta.get("autotune_remote_cache")) is not None: + return bool(config) + if not inductor_meta.get("is_fbcode"): + return False + if torch._utils_internal.is_fb_unit_test(): + return False + if inductor_meta.get("is_hip"): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:autotune_memcache_version" + ) + + +def 
_load_cached_autotuning( + best_config: Dict[str, JsonDataTy], + configs_hash: str, + configs: List[Config], + inductor_meta: Dict[str, object], +) -> Optional[Config]: + if best_config is None: + return None + if best_config.pop("configs_hash", None) != configs_hash: + return None + + # Remove time taken for comparison + best_config.pop("time_taken_ms", None) + + if inductor_meta.get("coordinate_descent_tuning") and best_config.pop( + "found_by_coordesc", False + ): + num_warps = best_config.pop("num_warps") + num_stages = best_config.pop("num_stages") + triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages) + triton_config.found_by_coordesc = True + return triton_config + + matching_configs = [ + cfg + for cfg in configs + if all(val == best_config.get(key) for key, val in cfg.kwargs.items()) + and cfg.num_warps == best_config.get("num_warps") + and cfg.num_stages == best_config.get("num_stages") + ] + if len(matching_configs) != 1: + return None + + return matching_configs[0] + + +def _create_cache( + inductor_meta: Dict[str, object], + configs_hash: str, + fb_cache_cls: str, + oss_cache_cls: str, + salt: str, +) -> Optional[RemoteCache[JsonDataTy]]: + backend_hash = inductor_meta.get("backend_hash", None) + if backend_hash is None: + log.debug( + "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache" + ) + return None + + assert isinstance(backend_hash, str) + + key = backend_hash + configs_hash + salt + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + + try: + if inductor_meta.get("is_fbcode"): + import torch._inductor.fb.remote_cache + + cache_cls = getattr(torch._inductor.fb.remote_cache, fb_cache_cls) + return cache_cls(key) + else: + import torch._inductor.remote_cache + + cache_cls = getattr(torch._inductor.remote_cache, oss_cache_cls) + return cache_cls(key) + except Exception: + log.warning("Unable to create a remote cache", exc_info=True) + return None + + +class _LocalAutotuneCacheBackend(RemoteCacheBackend[bytes]): + @override + def get(self, key: str) -> Optional[bytes]: + try: + with open(key, "rb") as fd: + return fd.read() + except FileNotFoundError: + return None + + @override + def put(self, key: str, data: bytes) -> None: + with open(key, "wb") as fd: + fd.write(data) diff --git a/.venv/Lib/site-packages/torch/_inductor/runtime/benchmarking.py b/.venv/Lib/site-packages/torch/_inductor/runtime/benchmarking.py new file mode 100644 index 0000000000000000000000000000000000000000..abfe27b2f0d117f1ddd2e69cf31399111309d5c1 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/runtime/benchmarking.py @@ -0,0 +1,204 @@ +import time +from functools import cached_property, wraps +from itertools import chain +from statistics import median +from typing import Any, Callable, Dict, List, Tuple +from typing_extensions import Concatenate, ParamSpec, Self, TypeVar + +import torch +from torch._dynamo.utils import counters + + +logger = torch._logging.getArtifactLogger(__name__, "benchmarking") + + +MILLISECONDS_PER_SECOND = 1000 + +P = ParamSpec("P") +T = TypeVar("T") + + +def maybe_time( + fn: Callable[Concatenate[Any, P], T] +) -> Callable[Concatenate[Any, P], T]: + """Wrapper that logs the duration of `fn`, in milliseconds, along with a representation + of the function's args and kwargs, if logging is enabled. It is expected that `fn` is + a method of `Benchmarker` or one of its subclasses; typing limitations prevent us from + declaring this directly. If logging is disabled, this becomes a no-op. 
+ """ + + # no-op if benchmarking-specific logging is disabled + if not torch._logging._internal.log_state.is_artifact_enabled("benchmarking"): + return fn + + @wraps(fn) + def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T: + start_t = time.perf_counter() + result = fn(*args, **kwargs) + logger.debug( + "Call `benchmarking.%s.%s(*args=%r, **kwargs=%r)` took %f milliseconds.", + self.__class__.__name__, + fn.__name__, + args, + kwargs, + (time.perf_counter() - start_t) * MILLISECONDS_PER_SECOND, + ) + return result + + return wrapper + + +def count(fn: Callable[Concatenate[Any, P], T]) -> Callable[Concatenate[Any, P], T]: + """Wrapper that increments relevant dynamo counters on `fn` call. It is expected that + `fn` is a method of `Benchmarker` or one of its subclass; typing limitations prevent + us from declaring this directly. The counter incrementation follows the formula, + + `counters["inductor"]["benchmarking.Foo.bar] += 1` + + where `Foo` is the class whose' instance called the function, and `bar` is the function name. + """ + + @wraps(fn) + def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T: + counters["inductor"][ + "benchmarking." + self.__class__.__name__ + "." + fn.__name__ + ] += 1 + return fn(self, *args, **kwargs) + + return wrapper + + +class Benchmarker: + def __init__(self: Self) -> None: + pass + + @maybe_time + @count + def benchmark( + self: Self, + fn: Callable[..., Any], + fn_args: Tuple[Any], + fn_kwargs: Dict[str, Any], + **kwargs: Any, + ) -> float: + """Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the + actual runtime calculation is dictated by the benchmarking implementation, but may be + one of [mean, median, minimum, etc.]). Functions as a convenience wrapper around + device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises + `ValueError(...)` if we can't safely infer the device type of `fn`; for example, + if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device + types are found. + + Arguments: + - fn: The function to benchmark. + - fn_args: The function's arguments. + - fn_kwargs: The function's kwargs. + + Keyword Arguments: + - **kwargs: The benchmarking implementation's kwargs. + + Returns: + - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds. + """ + inferred_device = None + for arg_or_kwarg in chain(fn_args, fn_kwargs.values()): + if not isinstance(arg_or_kwarg, torch.Tensor): + continue + if inferred_device is None: + inferred_device = arg_or_kwarg.device + elif arg_or_kwarg.device != inferred_device: + raise ValueError( + "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!" + ) + if inferred_device is None: + raise ValueError( + "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950 + ) + _callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731 + if inferred_device == torch.device("cpu"): + return self.benchmark_cpu(_callable, **kwargs) + # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking + # implementation which was written specifically with CUDA devices in mind, we may want to + # explore alternate implementations for other device types. 
+ return self.benchmark_gpu(_callable, **kwargs) + + @maybe_time + @count + def benchmark_cpu( + self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100 + ) -> float: + """Benchmark the CPU callable, `_callable`, and return the median runtime, + in milliseconds. + + Arguments: + - _callable: The CPU callable to benchmark. + + Keyword Arguments: + - warmup: Optionally, the duration, in milliseconds, to run `_callable` + before benchmarking starts. + - rep: Optionally, the duration, in milliseconds, to run `_callable` + during benchmarking. + + Returns: + - The median runtime of `_callable`, in milliseconds. + """ + + def run_for(ms: int) -> List[float]: + timings = [] + run_start_t = time.perf_counter() + while True: + start_t = time.perf_counter() + _callable() + end_t = time.perf_counter() + timings.append((end_t - start_t) * MILLISECONDS_PER_SECOND) + if ((end_t - run_start_t) * MILLISECONDS_PER_SECOND) > ms: + break + return timings + + run_for(warmup) + return median(run_for(rep)) + + @count + def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float: + raise NotImplementedError + + +class TritonBenchmarker(Benchmarker): + @cached_property + @maybe_time + @count + def triton_do_bench(self: Self) -> Callable[..., Any]: + """Lazily import Triton's `do_bench`.""" + try: + from triton.testing import do_bench + except ImportError as e: + raise NotImplementedError("requires Triton") from e + return do_bench + + @maybe_time + @count + def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> float: + """Benchmark the GPU callable, `_callable`, and return the runtime, in milliseconds. + + Arguments: + - _callable: The GPU callable to benchmark. + + Keyword Arguments: + - quantiles: Optionally, a tuple of floats denoting the requested quantiles. + - return_mode: Optionally, the requested return mode. Currently, Triton's + `do_bench` supports min, max, mean, and median return modes. + - **kwargs: Additional kwargs passed to Triton's `do_bench`. + + Returns: + - The runtime of `callable`, in milliseconds. If `kwargs["quantiles"]` is specified, + this is the first requested quantile. Else, if `kwargs["return_mode"]` is specified, + this is the requested return mode. Otherwise, this is the median. 
+ """ + if "quantiles" in kwargs: + return self.triton_do_bench(_callable, **kwargs)[0] + elif "return_mode" in kwargs: + return self.triton_do_bench(_callable, **kwargs) + return self.triton_do_bench(_callable, **kwargs, return_mode="median") + + +benchmarker = TritonBenchmarker() diff --git a/.venv/Lib/site-packages/torch/_inductor/runtime/compile_tasks.py b/.venv/Lib/site-packages/torch/_inductor/runtime/compile_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..bf09590f0d5fdc752843417c0ef0da4035d18082 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/runtime/compile_tasks.py @@ -0,0 +1,68 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import os +import sys +import warnings +from types import ModuleType +from typing import Any, Callable, Dict + + +def _reload_triton_kernel_in_subproc(reload_module, kernel_name): + return _module_to_triton_kernel(reload_module(), kernel_name) + + +def _module_to_triton_kernel(mod, kernel_name): + kernel = getattr(mod, kernel_name) + kernel._reload_in_subproc = functools.partial( + _reload_triton_kernel_in_subproc, + mod._reload_in_subproc, + kernel_name, + ) + return kernel + + +def _reload_python_module_in_subproc(key, path): + codecache = sys.modules.get("torch._inductor.codecache") + if codecache: + return codecache.PyCodeCache.load_by_key_path(key, path) + else: + return _reload_python_module(key, path) + + +def _reload_python_module(key, path): + with open(path) as f: + try: + code = compile(f.read(), path, "exec", dont_inherit=True) + except Exception as e: + raise RuntimeError( + f"Failed to import {path}\n{type(e).__name__}: {e}" + ) from None + mod = ModuleType(f"{__name__}.{key}") + mod.__file__ = path + mod.key = key # type: ignore[attr-defined] + exec(code, mod.__dict__, mod.__dict__) + sys.modules[mod.__name__] = mod + return mod + + +@functools.lru_cache(None) +def _set_triton_ptxas_path() -> None: + if os.environ.get("TRITON_PTXAS_PATH") is not None: + return + ptxas_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "bin", "ptxas") + ) + if not os.path.exists(ptxas_path): + return + if os.path.isfile(ptxas_path) and os.access(ptxas_path, os.X_OK): + os.environ["TRITON_PTXAS_PATH"] = ptxas_path + else: + warnings.warn(f"{ptxas_path} exists but is not an executable") + + +def _worker_compile_triton(load_kernel: Callable[[], Any], extra_env: Dict[str, str]): + _set_triton_ptxas_path() + os.environ.update(extra_env) + load_kernel().precompile(warm_cache_only=True) diff --git a/.venv/Lib/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py b/.venv/Lib/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..a88a21c35600c14fc253c4c689a3ffc9e5547833 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -0,0 +1,304 @@ +# mypy: allow-untyped-defs +import copy +import itertools +import logging +from typing import Callable, Optional + +from .hints import TRITON_MAX_BLOCK +from .runtime_utils import red_text, triton_config_to_hashable + + +try: + import triton +except ImportError: + triton = None + +log = logging.getLogger(__name__) + + +def get_field(config, name): + if name == "num_warps": + return config.num_warps + elif name == "num_stages": + return config.num_stages + else: + return config.kwargs.get(name, None) + + +def set_field(config, name, value): + if name == "num_warps": + config.num_warps = value + elif 
name == "num_stages": + config.num_stages = value + else: + config.kwargs[name] = value + + +class CoordescTuner: + """ + The coordinate descent tuner. Tune one field/coordinate at a time. + + TODO will it be necessary to tune multiple fields simultaneously. + + + TODO: what if both increasing and decreasing a field can improve perf. + i.e., there are multiple local optima.. + """ + + def __init__( + self, is_mm=False, name="unknown", size_hints=None, inductor_meta=None + ): + self.is_mm = is_mm # we will tune num_stages for mm + self.cached_benchmark_results = {} + self.name = name + self.size_hints = size_hints + self.inductor_meta = inductor_meta or {} + + def prefix_to_size_hint(self, prefix: str) -> Optional[int]: + size_hint_idx = {"X": 0, "Y": 1, "Z": 2, "R": -1}[prefix] + + have_size_hint = ( + self.size_hints is not None + and len(self.size_hints) > 0 + and len(self.size_hints) > size_hint_idx + ) + return self.size_hints[size_hint_idx] if have_size_hint else None + + def get_config_max(self, prefix: str) -> int: + max_block = TRITON_MAX_BLOCK[prefix] + size_hint = self.prefix_to_size_hint(prefix) + return min(max_block, size_hint) if size_hint is not None else max_block + + def get_warpsmax(self): + # Currently, CUDA has a maximum of 1024 threads, so 32 is the max + # number of warps. + return 1024 // 32 + + def cache_benchmark_result(self, config, timing): + self.cached_benchmark_results[triton_config_to_hashable(config)] = timing + + def lookup_in_cache(self, config): + return self.cached_benchmark_results.get(triton_config_to_hashable(config)) + + def call_func(self, func, config): + found = self.lookup_in_cache(config) + if found is not None: + log.debug(" CACHED") + return found + timing = func(config) + self.cache_benchmark_result(config, timing) + return timing + + @property + def tunable_fields(self): + out = [ + "XBLOCK", + "YBLOCK", + "ZBLOCK", + # NOTE: we should not tune RBLOCK for persistent reduction. + # We rely on the fact that persistent reduction's triton.Config + # does not have the RBLOCK field to guarantee that. + "RBLOCK", + # the following 3 are for mm + "BLOCK_M", + "BLOCK_N", + "BLOCK_K", + "num_warps", + ] + if self.is_mm: + out.append("num_stages") + + return out + + def value_too_large(self, name: str, val: int) -> bool: + if name in {"XBLOCK", "YBLOCK", "ZBLOCK", "RBLOCK"}: + return val > self.get_config_max(name[0]) + if name == "num_warps": + return val > self.get_warpsmax() + + return False + + def get_neighbour_values(self, name, orig_val, radius=1, include_self=False): + """ + Get neighbour values in 'radius' steps. The original value is not + returned as it's own neighbour. 
+ """ + assert radius >= 1 + + def update(cur_val, inc=True): + if name == "num_stages": + if inc: + return cur_val + 1 + else: + return cur_val - 1 + else: + if inc: + return cur_val * 2 + else: + return cur_val // 2 + + out = [] + # increment loop + cur_val = orig_val + for _ in range(radius): + cur_val = update(cur_val, True) + if self.value_too_large(name, cur_val): + break + out.append(cur_val) + + # decrement loop + cur_val = orig_val + for _ in range(radius): + cur_val = update(cur_val, False) + if cur_val <= 0: + break + out.append(cur_val) + + if include_self: + out.append(orig_val) + return out + + @staticmethod + def has_improvement(baseline, test): + threshold = 0.001 # 0.1% + return test is not None and test < baseline * (1 - threshold) + + def check_all_tuning_directions( + self, + func: Callable[["triton.Config"], float], + best_config, + best_timing, + ): + """ + Check all directions. We only do this once the regular coordinate + descent tuning find no better choices any more. + We only have a few tunable fields, so this should be fine. + """ + candidate_values_list = [] + effective_fields = [] + for field in self.tunable_fields: + old_value = get_field(best_config, field) + if old_value is None: + continue + candidate_values = self.get_neighbour_values( + field, + old_value, + radius=self.inductor_meta.get("coordinate_descent_search_radius", 1), + include_self=True, + ) + candidate_values_list.append(candidate_values) + effective_fields.append(field) + + choices = itertools.product(*candidate_values_list) + improved = False + for choice in choices: + assert len(choice) == len(effective_fields) + candidate_config = copy.deepcopy(best_config) + for new_val, field in zip(choice, effective_fields): + set_field(candidate_config, field, new_val) + cmp_res, candidate_timing = self.compare_config( + func, candidate_config, best_config, best_timing + ) + if cmp_res: + improved = True + best_config = candidate_config + best_timing = candidate_timing + + return improved, best_config, best_timing + + def compare_config(self, func, candidate_config, best_config, best_timing): + """ + Check if candidate_config is better than best_config. + + Return a touple of (compare_result, candidate_timing). + compare_result is true iff candidate_config is better. + """ + log.debug("Try config %s", candidate_config) + try: + candidate_timing = self.call_func(func, candidate_config) + except Exception as e: + log.debug("Got exception %s", e) + return False, float("inf") + + if self.has_improvement(best_timing, candidate_timing): + log.debug( + "Tune from %s %f -> %s %f", + best_config, + best_timing, + candidate_config, + candidate_timing, + ) + + return True, candidate_timing + return False, candidate_timing + + def autotune( + self, + func: Callable[["triton.Config"], float], + baseline_config: "triton.Config", + baseline_timing: Optional[float] = None, + ) -> "triton.Config": + if baseline_timing is None: + baseline_timing = self.call_func(func, baseline_config) + + log.debug("= Do coordinate descent tuning for %s =", self.name) + log.debug( + "Baseline Config %s, baseline timing %f", baseline_config, baseline_timing + ) + improved = True + best_config = baseline_config + best_timing = baseline_timing + tunable_fields = self.tunable_fields + + while improved: + improved = False + + for name in tunable_fields: + cur_val = get_field(best_config, name) + # some kernel don't have RBLOCK/YBLOCK/ZBLOCK. 
So cur_val may be None + if cur_val is None: + continue + + # It's possible that candidate_values is empty. + # E.g., if XBLOCK is 1 initially and size_hint for x is also 1. + # We would not try either larger or smaller XBLOCK in this case. + candidate_values = self.get_neighbour_values(name, cur_val) + + for next_val in candidate_values: + candidate_config = copy.deepcopy(best_config) + set_field(candidate_config, name, next_val) + + cmp_res, candidate_timing = self.compare_config( + func, candidate_config, best_config, best_timing + ) + if cmp_res: + improved = True + best_config, best_timing = candidate_config, candidate_timing + + if not improved and self.inductor_meta.get( + "coordinate_descent_check_all_directions" + ): + old_best_timing = best_timing + improved, best_config, best_timing = self.check_all_tuning_directions( + func, best_config, best_timing + ) + + if improved: + msg = red_text( + "Coordinate descend tuning found improvement of %.3fx by looking in all directions." + ) + log.debug( + msg, + old_best_timing / best_timing, + ) + + log.debug( + "Improve from %s %f -> %s %f, %.3fx", + baseline_config, + baseline_timing, + best_config, + best_timing, + baseline_timing / best_timing, + ) + + return best_config diff --git a/.venv/Lib/site-packages/torch/_inductor/runtime/halide_helpers.py b/.venv/Lib/site-packages/torch/_inductor/runtime/halide_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..8271a6bdd391ee31bc092afbdceaeb419cbac03d --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/runtime/halide_helpers.py @@ -0,0 +1,118 @@ +# mypy: allow-untyped-defs +try: + import halide as hl # type: ignore[import-untyped, import-not-found] +except ImportError: + hl = None + +PHILOX_N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox + +if hl is not None: + PHILOX_KEY_A_U32 = hl.u32(0x9E3779B9) + PHILOX_KEY_B_U32 = hl.u32(0xBB67AE85) + PHILOX_ROUND_A_U32 = hl.u32(0xD2511F53) + PHILOX_ROUND_B_U32 = hl.u32(0xCD9E8D57) +else: + PHILOX_KEY_A_U32 = None + PHILOX_KEY_B_U32 = None + PHILOX_ROUND_A_U32 = None + PHILOX_ROUND_B_U32 = None + + +def _pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = hl.max(hl.f32(1.0e-7), u1) + th = hl.f32(6.283185307179586) * u2 + r = hl.sqrt(hl.f32(-2.0) * hl.log(u1)) + return r * hl.cos(th), r * hl.sin(th) + + +def _uint_to_uniform_float(x): + """ + Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1). + """ + + # TODO: + # conditions can be simplified + # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1) + # https://github.com/triton-lang/triton/blob/e4a0d93ff1a367c7d4eeebbcd7079ed267e6b06f/python/triton/language/random.py#L116-L132. 
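+ # The value is folded to a non-negative int32 (negatives map through -x - 1) and multiplied by a scale of roughly 2**-31 so the result lies in [0, 1).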
+ assert x.type() == hl.UInt(32) or x.type() == hl.Int(32) + x = hl.cast(hl.Int(32), x) + scale = hl.f64(4.6566127342e-10) + x = hl.select(x < 0, -x - 1, x) + return x * scale + + +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds): + def umulhi(a, b): + a = hl.cast(hl.UInt(64), a) + b = hl.cast(hl.UInt(64), b) + return hl.cast(hl.UInt(32), ((a * b) >> 32) & hl.u64(0xFFFFFFFF)) + + for _ in range(n_rounds): + _c0, _c2 = c0, c2 + + c0 = umulhi(PHILOX_ROUND_B_U32, _c2) ^ c1 ^ k0 + c2 = umulhi(PHILOX_ROUND_A_U32, _c0) ^ c3 ^ k1 + c1 = PHILOX_ROUND_B_U32 * _c2 + c3 = PHILOX_ROUND_A_U32 * _c0 + # raise key + k0 = k0 + PHILOX_KEY_A_U32 + k1 = k1 + PHILOX_KEY_B_U32 + + return c0, c1, c2, c3 + + +def halide_philox(seed, c0, c1, c2, c3, n_rounds): + seed = hl.cast(hl.UInt(64), seed) + + assert c0.type().bits() == 32 + + seed_hi = hl.cast(hl.UInt(32), (seed >> 32) & hl.u64(0xFFFFFFFF)) + seed_lo = hl.cast(hl.UInt(32), seed & hl.u64(0xFFFFFFFF)) + + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + +def randint4x(seed, offset, n_rounds): + offset = hl.cast(hl.UInt(32), offset) + _0 = hl.u32(0) + return halide_philox(seed, offset, _0, _0, _0, n_rounds) + + +def rand4x(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT): + i1, i2, i3, i4 = randint4x(seed, offset, n_rounds) + u1 = _uint_to_uniform_float(i1) + u2 = _uint_to_uniform_float(i2) + u3 = _uint_to_uniform_float(i3) + u4 = _uint_to_uniform_float(i4) + return u1, u2, u3, u4 + + +def randint(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT): + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret + + +def rand(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT): + source = randint(seed, offset, n_rounds) + return _uint_to_uniform_float(source) + + +def randn(seed, offset): + i1, i2, _, _ = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT) + u1 = _uint_to_uniform_float(i1) + u2 = _uint_to_uniform_float(i2) + n1, _ = _pair_uniform_to_normal(u1, u2) + return n1 + + +def randint64(seed, offset, low, high): + r0, r1, r2, r3 = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT) + r0 = hl.cast(hl.UInt(64), r0) + r1 = hl.cast(hl.UInt(64), r1) + + result = r0 | (r1 << 32) + size = high - low + result = result % hl.cast(hl.UInt(64), size) + result = hl.cast(hl.Int(64), result) + low + return result diff --git a/.venv/Lib/site-packages/torch/_inductor/runtime/hints.py b/.venv/Lib/site-packages/torch/_inductor/runtime/hints.py new file mode 100644 index 0000000000000000000000000000000000000000..8c3cdf3f7a2faa203be06505b2863af58732d0c9 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/runtime/hints.py @@ -0,0 +1,179 @@ +# mypy: allow-untyped-defs +import collections +import typing +from dataclasses import fields +from enum import auto, Enum +from typing import Dict, List, Optional, Union + + +# NOTE: if these fail asserts submit a PR to increase them +TRITON_MAX_BLOCK = { + "X": 2048, + "Y": 1024, + "Z": 1024, + "R": 4096 * 16, # * 16 is multi-kernel only +} + + +class ReductionHint(Enum): + INNER = 0 + OUTER = 1 + OUTER_TINY = 2 + DEFAULT = 3 + + +class TileHint(Enum): + SQUARE = 0 + DEFAULT = 1 + + +# Attempt to import AttrsDescriptor from Triton +try: + from triton.compiler.compiler import AttrsDescriptor + + attrs_descriptor_available = True + # Determine if 'ids_of_folded_args' is a valid field for AttrsDescriptor + attr_desc_fields = {f.name for f in fields(AttrsDescriptor)} + ids_of_folded_args_available = "ids_of_folded_args" in attr_desc_fields + divisible_by_8_available = "divisible_by_8" in attr_desc_fields +except ImportError: + 
attrs_descriptor_available = False + +# Define `instance_descriptor` function with clear conditional handling +if attrs_descriptor_available: + + def instance_descriptor( + divisible_by_16=None, + equal_to_1=None, + ids_of_folded_args=None, + divisible_by_8=None, + ): + # Prepare the arguments for AttrsDescriptor + kwargs = { + "divisible_by_16": divisible_by_16, + "equal_to_1": equal_to_1, + } + + # Conditionally add 'ids_of_folded_args' if it's available in AttrsDescriptor + if ids_of_folded_args_available: + kwargs["ids_of_folded_args"] = ids_of_folded_args + if divisible_by_8_available: + kwargs["divisible_by_8"] = divisible_by_8 + + # Instantiate AttrsDescriptor with the prepared arguments + return AttrsDescriptor(**kwargs) + +else: + # Define a namedtuple as a fallback when AttrsDescriptor is not available + instance_descriptor = collections.namedtuple( # type: ignore[no-redef] + "instance_descriptor", + ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], + defaults=[(), (), (), ()], + ) + + +_NUM_THREADS_PER_WARP = 32 + + +class HeuristicType(Enum): + PERSISTENT_REDUCTION = auto() + POINTWISE = auto() + REDUCTION = auto() + SPLIT_SCAN = auto() + TEMPLATE = auto() + USER_AUTOTUNE = auto() + + +class AutotuneHint(Enum): + ELEMENTS_PER_WARP_32 = 0 + + # Triton codegen tries to codegen set of AutotuneHints. + # Enum.__repr__ looks like """ + # which isn't valid python. + # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32". + __repr__ = Enum.__str__ + + +class DeviceProperties(typing.NamedTuple): + """Copy device properties into a data structure not requiring torch to be imported""" + + type: str # type: ignore[assignment] + index: int # type: ignore[assignment] + cc: int + major: Optional[int] = None + regs_per_multiprocessor: Optional[int] = None + max_threads_per_multi_processor: Optional[int] = None + multi_processor_count: Optional[int] = None + + @classmethod + def create(cls, device): + import torch + from torch._dynamo.device_interface import get_interface_for_device + + device_type = device.type if torch.version.hip is None else "hip" + device_interface = get_interface_for_device(device) + if device_type == "cuda": + props = device_interface.get_device_properties(device) + return cls( + type=device_type, + index=device.index, + cc=device_interface.get_compute_capability(device), + major=props.major, + regs_per_multiprocessor=props.regs_per_multiprocessor, + max_threads_per_multi_processor=props.max_threads_per_multi_processor, + multi_processor_count=props.multi_processor_count, + ) + return cls( + type=device_type, + index=device.index, + cc=device_interface.get_compute_capability(device), + ) + + +class HalideInputSpec(typing.NamedTuple): + ctype: str + name: str + shape: Optional[List[str]] = None + stride: Optional[List[str]] = None + offset: Optional[str] = None + alias_of: Optional[str] = None + + def bindings_type(self): + if self.ctype in ("half*", "bfloat16*"): + return "uint16_t*" # half not defined + return self.ctype + + def halide_type(self): + if self.ctype == "half*": + return "halide_type_t(halide_type_float, 16)" # half not defined + if self.ctype == "bfloat16*": + return "halide_type_t(halide_type_bfloat, 16)" # half not defined + return f"halide_type_of<{self.ctype.replace('*', '')}>()" + + def is_scalar(self): + return self.shape is None + + def is_buffer(self): + return self.shape is not None + + +class HalideMeta(typing.NamedTuple): + argtypes: List[HalideInputSpec] + target: str + scheduler: Optional[str] = None + 
scheduler_flags: Optional[Dict[str, Union[int, str]]] = None + cuda_device: Optional[int] = None + + def args(self): + """Command line args to pass to halide generator""" + args = [f"target={self.target}"] + if self.scheduler: + args.append(f"autoscheduler={self.scheduler}") + if self.scheduler_flags: + assert self.scheduler + for k, v in self.scheduler_flags.items(): + args.append(f"autoscheduler.{k}={v}") + return args + + def is_cuda(self): + return self.cuda_device is not None diff --git a/.venv/Lib/site-packages/torch/_inductor/runtime/runtime_utils.py b/.venv/Lib/site-packages/torch/_inductor/runtime/runtime_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..94edbd2c3d45240f0655191e7bc3f2710febfb7a --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/runtime/runtime_utils.py @@ -0,0 +1,154 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import contextlib +import functools +import getpass +import operator +import os +import re +import tempfile + +import torch + + +def conditional_product(*args): + return functools.reduce(operator.mul, [x for x in args if x]) + + +def ceildiv(numer: int, denom: int) -> int: + return -(numer // -denom) + + +def is_power_of_2(n: int) -> bool: + """Returns whether n = 2 ** m for some integer m.""" + return n > 0 and n & n - 1 == 0 + + +def next_power_of_2(n: int) -> int: + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n + + +def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int: + """ + Return the total number of bytes the arguments of tensor type takes. + + For in/out args, tensor sizes are counted twice: once for reading and + once for writing. + + The first num_in_out_args arguments are in out tensors. + """ + return sum( + arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args)) + for i, arg in enumerate(args) + if isinstance(arg, torch.Tensor) + ) + + +def triton_config_to_hashable(cfg): + """ + Convert triton config to a tuple that can uniquely identify it. We can use + the return value as a dictionary key. + """ + items = sorted(cfg.kwargs.items()) + items.append(("num_warps", cfg.num_warps)) + items.append(("num_stages", cfg.num_stages)) + return tuple(items) + + +def validate_triton_config(cfg): + # [Note: Triton pre_hook in inductor] + # pre-hook is a lambda function, which we don't attempt to serialize. + # right now, if a pre-hook is attached to the config, it will not be saved; + # and then it won't be used when the config is loaded from cache. + # So we assert - if we do get a pre_hook, it might get ignored after caching. 
+ assert ( + getattr(cfg, "pre_hook", None) is None + ), "triton configs with pre_hooks not supported" + + +def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True): + info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}" + slow = ms > 0.012 and gb_per_s < 650 + return red_text(info_str) if color and slow else info_str + + +def get_max_y_grid(): + return 65535 + + +def cache_dir() -> str: + cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") + if cache_dir is None: + os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir() + os.makedirs(cache_dir, exist_ok=True) + return cache_dir + + +def default_cache_dir(): + sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser()) + return os.path.join( + tempfile.gettempdir(), + "torchinductor_" + sanitized_username, + ) + + +try: + import colorama + + HAS_COLORAMA = True +except ModuleNotFoundError: + HAS_COLORAMA = False + colorama = None # type: ignore[assignment] + + +def _color_text(msg, color): + if not HAS_COLORAMA: + return msg + + return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET + + +def green_text(msg): + return _color_text(msg, "green") + + +def yellow_text(msg): + return _color_text(msg, "yellow") + + +def red_text(msg): + return _color_text(msg, "red") + + +def blue_text(msg): + return _color_text(msg, "blue") + + +def get_first_attr(obj, *attrs): + """ + Return the first available attribute or throw an exception if none is present. + """ + for attr in attrs: + if hasattr(obj, attr): + return getattr(obj, attr) + + raise AssertionError(f"{obj} does not has any of the attributes: {attrs}") + + +try: + dynamo_timed = torch._dynamo.utils.dynamo_timed # type: ignore[has-type] +except AttributeError: # Compile workers only have a mock version of torch + + @contextlib.contextmanager + def dynamo_timed(key, phase_name=None, fwd_only=True): + yield diff --git a/.venv/Lib/site-packages/torch/_inductor/runtime/triton_helpers.py b/.venv/Lib/site-packages/torch/_inductor/runtime/triton_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..d682a082a14dfe45c9cab1a378e395985bcb1efe --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/runtime/triton_helpers.py @@ -0,0 +1,542 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import triton +import triton.language as tl + + +# In the latest triton, math functions were shuffled around into different modules: +# https://github.com/openai/triton/pull/3172 +try: + from triton.language.extra import libdevice + + libdevice = tl.extra.libdevice # noqa: F811 + math = tl.math +except ImportError: + if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"): + libdevice = tl.extra.cuda.libdevice + math = tl.math + elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"): + libdevice = tl.extra.intel.libdevice + math = tl.math + else: + libdevice = tl.math + math = tl + + +try: + from triton.language.standard import _log2 +except ImportError: + + def _log2(x): + raise NotImplementedError + + +@triton.jit +def promote_to_tensor(x): + # Addition promotes to tensor for us + return x + tl.zeros((1,), tl.int1) + + +@triton.jit +def div_floor_integer(a, b): + # NOTE: a // b is C division, but we want floor division + # Based on c10::div_floor_integer + quot = a // b + remainder = a % b + fixed = tl.where(remainder != 0, quot - 1, quot) + return tl.where((a < 0) != (b < 0), fixed, quot) + + +@triton.jit +def remainder_integer(a, b): + # NOTE: 
a % b matches C division, not floor division + remainder = a % b + return tl.where(remainder != 0 and ((a < 0) != (b < 0)), remainder + b, remainder) + + +@triton.jit +def is_floating(x): + return promote_to_tensor(x).dtype.is_floating() + + +@triton.jit +def _prod_accumulate(a, b): + return a * b + + +@triton.jit +def prod(input, axis): + return tl.reduce(input, axis, _prod_accumulate) + + +@triton.jit +def minimum(a, b): + mask = a < b + if is_floating(a): + mask |= a != a + return tl.where(mask, a, b) + + +@triton.jit +def maximum(a, b): + mask = a > b + if is_floating(a): + mask |= a != a + return tl.where(mask, a, b) + + +@triton.jit +def min2(a, dim): + return tl.reduce(a, dim, minimum) + + +@triton.jit +def max2(a, dim): + return tl.reduce(a, dim, maximum) + + +@triton.jit +def minimum_with_index(a_value, a_index, b_value, b_index): + mask = a_value < b_value + equal = a_value == b_value + if is_floating(a_value): + a_isnan = a_value != a_value + b_isnan = b_value != b_value + mask |= a_isnan and not b_isnan + # Consider NaNs as equal + equal |= a_isnan and b_isnan + + # Prefer lowest index if values are equal + mask |= equal & (a_index < b_index) + return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) + + +@triton.jit +def maximum_with_index(a_value, a_index, b_value, b_index): + mask = a_value > b_value + equal = a_value == b_value + if is_floating(a_value): + a_isnan = a_value != a_value + b_isnan = b_value != b_value + mask |= a_isnan and not b_isnan + # Consider NaNs as equal + equal |= a_isnan and b_isnan + + # Prefer lowest index if values are equal + mask |= equal & (a_index < b_index) + return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) + + +@triton.jit +def min_with_index(value, index, dim): + return tl.reduce((value, index), dim, minimum_with_index) + + +@triton.jit +def max_with_index(value, index, dim): + return tl.reduce((value, index), dim, maximum_with_index) + + +@triton.jit +def welford_reduce(value, mean, m2, weight, first_iteration): + if first_iteration: + new_weight = tl.full(weight.shape, 1, weight.dtype) + new_mean = value + new_m2 = tl.zeros_like(m2) + else: + delta = value - mean + new_weight = weight + 1 + new_mean = mean + delta / new_weight + new_m2 = m2 + delta * (value - new_mean) + return new_mean, new_m2, new_weight + + +@triton.jit +def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): + delta = mean_2 - mean_1 + new_weight = weight_1 + weight_2 + w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight) + return ( + mean_1 + delta * w2_over_w, + m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w, + new_weight, + ) + + +@triton.jit +def welford(mean, m2, weight, dim): + return tl.reduce((mean, m2, weight), dim, welford_combine) + + +@triton.jit +def device_assert_then(cond, msg, r): + tl.device_assert(cond, msg) + return r + + +@triton.jit +def randint64(seed, offset, low, high): + r0, r1, r2, r3 = tl.randint4x(seed, offset) + r0 = r0.to(tl.uint64) + r1 = r1.to(tl.uint64) + result = r0 | (r1 << 32) + size = high - low + result = result % size.to(tl.uint64) + result = result.to(tl.int64) + low + return result + + +@triton.jit +def _any_combine(a, b): + return a | b + + +@triton.jit +def any(a, dim): + return tl.reduce(a, dim, _any_combine) + + +@triton.jit +def bucketize_binary_search( + values, # 1D tensor + offsets_ptr, + indexing_dtype, + right, # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op] + OFFSETS_SIZE: int, + BLOCK_SHAPE, # tuple/list of 
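# ---------------------------------------------------------------------------
# A minimal pure-Python sketch of the Welford merge used by welford_reduce /
# welford_combine above: each block keeps a (mean, m2, weight) triple, and
# combining two triples gives the same result as a single pass over all the
# data, so variance reductions can be split across blocks.  Plain floats are
# used here purely for illustration.
def welford_combine_py(mean_1, m2_1, w_1, mean_2, m2_2, w_2):
    delta = mean_2 - mean_1
    w = w_1 + w_2
    w2_over_w = 0.0 if w == 0.0 else w_2 / w
    return (mean_1 + delta * w2_over_w,
            m2_1 + m2_2 + delta * delta * w_1 * w2_over_w,
            w)

def welford_of(xs):
    mean, m2, w = 0.0, 0.0, 0.0
    for x in xs:                        # one element at a time, like welford_reduce
        w += 1.0
        delta = x - mean
        mean += delta / w
        m2 += delta * (x - mean)
    return mean, m2, w                  # variance of xs is m2 / w

data = [1.0, 2.0, 4.0, 8.0, 16.0, 32.0]
whole = welford_of(data)
merged = welford_combine_py(*welford_of(data[:3]), *welford_of(data[3:]))
assert all(abs(a - b) < 1e-9 for a, b in zip(whole, merged))
# ---------------------------------------------------------------------------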
block shape +): + """ + See [Note: Inductor bucketize op] + """ + + low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype) + high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype) + + full_range = OFFSETS_SIZE + 1 + while full_range > 1: + mid = (high + low) // 2 + mask = mid < OFFSETS_SIZE + bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask, other=0.0) + if right: + is_above = values >= bucket_upper_bound + else: + is_above = values > bucket_upper_bound + + low = tl.where(is_above & mask, mid + 1, low) + high = tl.where(is_above, high, mid) + + full_range = (full_range + 1) // 2 + + return low + + +@triton.jit +def pack_value_flag( + value, + flag, + DTYPE_VALUE_AS_UINT: tl.constexpr, + DTYPE_PACK: tl.constexpr, +): + # Workaround for triton bug, tensor.to doesn't unwrap constexpr values + DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT) + bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth + uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK) + return flag.to(DTYPE_PACK) | (uv << bitwidth) + + +@triton.jit +def unpack_value( + pack, + DTYPE_VALUE, + DTYPE_VALUE_AS_UINT, +): + # Workaround for triton bug, tensor.to doesn't unwrap constexpr values + DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE) + DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT) + bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth + value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT) + return value_uint.to(DTYPE_VALUE, bitcast=True) + + +@triton.jit +def unpack_flag(pack, DTYPE_FLAG): + return pack.to(DTYPE_FLAG) + + +@triton.jit +def exclusive_scan_decoupled_lookback( + scratch_base, + block_value, + index, + combine_fn, + DTYPE_VALUE_AS_UINT: tl.constexpr, + DTYPE_PACK: tl.constexpr, +): + """Compute exclusive scan of a scalar value between blocks + + Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + scratch_base: Pointer to scratch space in global memory + block_value: Scalar value for this block + index: Scalar index of this block relative to the current scan + combine_fn: Function ``(value, value) -> value`` which is scanned over + DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value`` + DTYPE_PACK: Unsigned type twice the width of block_value + + NOTE: This function is limited to values which are 32-bits or less because + we need to pack (value, flag) into a single unsigned int. 
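# ---------------------------------------------------------------------------
# A minimal pure-Python sketch of the packing scheme used by pack_value_flag /
# unpack_value / unpack_flag above: a value and a status flag share one wider
# unsigned integer so that a single atomic exchange can publish both at once.
# The example packs a float32 and a flag into 64 bits, i.e. it assumes
# DTYPE_VALUE_AS_UINT is a 32-bit type, as the note above requires.
import struct

BITWIDTH = 32                                       # bit width of the value half

def pack(value: float, flag: int) -> int:
    as_uint = struct.unpack("<I", struct.pack("<f", value))[0]  # bitcast f32 -> u32
    return flag | (as_uint << BITWIDTH)

def unpack_value_py(packed: int) -> float:
    as_uint = (packed >> BITWIDTH) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", as_uint))[0]   # bitcast back to f32

def unpack_flag_py(packed: int) -> int:
    return packed & 0xFFFFFFFF

p = pack(3.5, 2)                                    # flag 2: inclusive prefix ready
assert unpack_value_py(p) == 3.5 and unpack_flag_py(p) == 2
# ---------------------------------------------------------------------------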
+ """ + # Publish block sum so subsequent blocks don't get stuck waiting for us + DTYPE_VALUE = block_value.dtype + pack = pack_value_flag( + block_value, + tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT), + DTYPE_VALUE_AS_UINT, + DTYPE_PACK, + ) + if index > 0: + tl.atomic_xchg(scratch_base + index, pack, sem="relaxed") + + # Calculate exclusive prefix scan + exclusive_prefix = tl.zeros([], DTYPE_VALUE) + prefix_valid = False + test_target = index - 1 + while test_target >= 0: + # tl.atomic_load + flag = tl.full([], 0, DTYPE_VALUE_AS_UINT) + while flag == 0: + pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed") + flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT) + + value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT) + if prefix_valid: + exclusive_prefix = combine_fn(value, exclusive_prefix) + else: + exclusive_prefix = value + prefix_valid = True + + if flag == 2: + test_target = -1 + else: + test_target = test_target - 1 + + # Make inclusive block sum visible to other blocks + if prefix_valid: + inclusive_prefix = combine_fn(exclusive_prefix, block_value) + else: + inclusive_prefix = block_value + pack = pack_value_flag( + inclusive_prefix, + tl.full([], 2, DTYPE_VALUE_AS_UINT), + DTYPE_VALUE_AS_UINT, + DTYPE_PACK, + ) + tl.atomic_xchg(scratch_base + index, pack, sem="relaxed") + return exclusive_prefix + + +@triton.jit +def exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn): + """Compute exclusive scan of a scalar value between blocks + + Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + scratch_base: Pointer to scratch space in global memory + block_value: Scalar value for this block, must be 64-bits wide + index: Scalar index of this block relative to the current scan + combine_fn: Function ``(value, value) -> value`` which is scanned over + init: Scalar value equal to the identiy of combine_fn + """ + # Publish block sum so subsequent blocks don't get stuck waiting for us + if index > 0: + block_value_u64 = block_value.to(tl.uint64, bitcast=True) + tl.store(scratch_base + 3 * index + 1, block_value_u64) + tl.debug_barrier() + flag_one = tl.full([], 1, tl.uint64) + tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release") + + # Calculate exclusive prefix scan + exclusive_prefix = tl.zeros([], block_value.dtype) + prefix_valid = False + test_target = index - 1 + while test_target >= 0: + flag = tl.full([], 0, tl.uint64) + while flag == 0: + flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire") + + value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32)) + value = value_u64.to(block_value.dtype, bitcast=True) + if prefix_valid: + exclusive_prefix = combine_fn(value, exclusive_prefix) + else: + exclusive_prefix = value + prefix_valid = True + + if flag == 2: + test_target = -1 + else: + test_target = test_target - 1 + + # Make inclusive block sum visible to other blocks + if prefix_valid: + inclusive_prefix = combine_fn(exclusive_prefix, block_value) + else: + inclusive_prefix = block_value + inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True) + tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64) + tl.debug_barrier() + flag_two = tl.full([], 2, tl.uint64) + tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release") + + return exclusive_prefix + + +@triton.jit +def frexp(x): + # TODO(isuruf): use inline_asm_elementwise here + y = libdevice.ilogb(x) + 1 + exponent = tl.where(x == 0, 0, y) + 
mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y)) + return mantissa, exponent + + +@triton.jit +def _compare_and_swap_with_index( + x, + idxs, + rnumel, + flip, + i: tl.constexpr, + n_dims: tl.constexpr, + stable: tl.constexpr, + descending: tl.constexpr, +): + n_outer: tl.constexpr = x.numel >> n_dims + shape: tl.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)] + + idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + + y = tl.reshape(x, shape) + iy = y.to(idtype, bitcast=True) + # slice left/right with 'stride' 2**(n_dims - i - 1) + right_mask = tl.arange(0, 2)[None, :, None].to(idtype) + left_mask = (1 - right_mask).to(idtype) + ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1)[:, None, :], shape) + iright = tl.broadcast_to(tl.sum(iy * right_mask, 1)[:, None, :], shape) + ileft = tl.reshape(ileft, x.shape) + iright = tl.reshape(iright, x.shape) + left = ileft.to(x.dtype, bitcast=True) + right = iright.to(x.dtype, bitcast=True) + + # idx + y_idx = tl.reshape(idxs, shape) + left_idx = tl.broadcast_to( + tl.sum(y_idx * left_mask.to(y_idx.dtype), 1)[:, None, :], shape + ) + right_idx = tl.broadcast_to( + tl.sum(y_idx * right_mask.to(y_idx.dtype), 1)[:, None, :], shape + ) + left_idx = tl.reshape(left_idx, x.shape) + right_idx = tl.reshape(right_idx, x.shape) + + # valid + if rnumel is None: + left_valid_mask = tl.full(x.shape, True, tl.int1) + right_valid_mask = tl.full(x.shape, True, tl.int1) + else: + left_valid_mask = left_idx < rnumel + right_valid_mask = right_idx < rnumel + + # actual compare-and-swap + ix = x.to(idtype, bitcast=True) + + if descending: + cond = left < right + else: + cond = left > right + + if stable: + # When stable sorting, tie break by index + cond = cond | ((left == right) & (left_idx > right_idx)) + + cond = (right_valid_mask > left_valid_mask) | ( + (right_valid_mask == left_valid_mask) & cond + ) + cond = cond ^ flip + ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix)) + new_idxs = idxs ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(idxs)) + + return ret.to(x.dtype, bitcast=True), new_idxs + + +@triton.jit +def _bitonic_merge_with_index( + x, + idxs, + rnumel, + stage: tl.constexpr, + alternating: tl.constexpr, + n_dims: tl.constexpr, + stable: tl.constexpr, + descending: tl.constexpr, +): + n_outer: tl.constexpr = x.numel >> n_dims + tl.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... 
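# ---------------------------------------------------------------------------
# A minimal pure-Python sketch of the branch-free swap used by
# _compare_and_swap_with_index above: instead of selecting values with
# branches, the kernel computes ret = x ^ (cond ? (left ^ right) : 0), which
# turns each lane into its partner exactly when the pair is out of order.
# The scalar mirror below ignores the flip/validity masks and just shows the
# XOR trick on one (value, index) pair of plain ints.
def cond_swap(left, right, left_idx, right_idx, descending=False):
    out_of_order = (left < right) if descending else (left > right)
    mask = -1 if out_of_order else 0            # all-ones selects the partner
    delta_v, delta_i = left ^ right, left_idx ^ right_idx
    return (left ^ (delta_v & mask), right ^ (delta_v & mask),
            left_idx ^ (delta_i & mask), right_idx ^ (delta_i & mask))

assert cond_swap(7, 2, 0, 1) == (2, 7, 1, 0)    # ascending order: 7 > 2, so swap
assert cond_swap(2, 7, 0, 1) == (2, 7, 0, 1)    # already ordered, left untouched
# ---------------------------------------------------------------------------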
then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if alternating: + shape: tl.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage] + flip = tl.reshape( + tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape + ) + else: + flip = False + # perform `stage` rounds of `compare-and-swap` + for i in tl.static_range(stage): + x, idxs = _compare_and_swap_with_index( + x, idxs, rnumel, flip, i + (n_dims - stage), n_dims, stable, descending + ) + return x, idxs + + +@triton.jit +def sort_with_index( + x, # value + idxs, # index + rnumel, # number of elements + dim: tl.constexpr = None, + stable: tl.constexpr = tl.constexpr(False), + descending: tl.constexpr = tl.constexpr(False), +): + x, idxs = tl.broadcast(x, idxs) + # handle default dimension or check that it is the most minor dim + _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim + tl.static_assert( + _dim == len(x.shape) - 1, "only minor dimension is currently supported" + ) + # iteratively run bitonic merge-sort steps + n_dims: tl.constexpr = _log2(x.shape[_dim]) + + for i in tl.static_range(1, n_dims + 1): + x, idxs = _bitonic_merge_with_index( + x, + idxs, + rnumel, + i, + alternating=i < n_dims, + n_dims=n_dims, + stable=stable, + descending=descending, + ) + return x, idxs + + +@triton.jit +def select_one(x, mask, dim, keep_dims=False): + idtype = tl.core.get_int_dtype(x.dtype.primitive_bitwidth, signed=False) + ix = x.to(idtype, bitcast=True) + iy = tl.sum(ix * mask, dim, keep_dims=keep_dims) + return iy.to(x.dtype, bitcast=True) diff --git a/.venv/Lib/site-packages/torch/_inductor/runtime/triton_heuristics.py b/.venv/Lib/site-packages/torch/_inductor/runtime/triton_heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..fb49c26bacca7aa8337580ac96bf98d59cee8b8b --- /dev/null +++ b/.venv/Lib/site-packages/torch/_inductor/runtime/triton_heuristics.py @@ -0,0 +1,1807 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import builtins +import copy +import functools +import hashlib +import inspect +import logging +import math +import operator +import os +import os.path +import re +import sys +import threading +import time +from typing import Any, Dict, List, Optional, Set, Tuple + +import torch + +from .autotune_cache import AutotuneCache +from .benchmarking import benchmarker +from .coordinate_descent_tuner import CoordescTuner +from .hints import ( + _NUM_THREADS_PER_WARP, + AutotuneHint, + DeviceProperties, + HeuristicType, + ReductionHint, + TileHint, + TRITON_MAX_BLOCK, +) +from .runtime_utils import ( + cache_dir, + ceildiv, + conditional_product, + create_bandwidth_info_str, + dynamo_timed, + get_first_attr, + get_max_y_grid, + get_num_bytes, + next_power_of_2, + triton_config_to_hashable, + validate_triton_config, +) + + +try: + import triton +except ImportError: + triton = None + +if triton is not None: + from triton import Config + from triton.compiler import CompiledKernel + from triton.runtime.autotuner import OutOfResources + from triton.runtime.jit import KernelInterface + + try: + from triton.compiler.compiler import ASTSource + except ImportError: + ASTSource = None + + try: + from triton.backends.compiler import GPUTarget + except ImportError: + GPUTarget = None +else: + Config = object + KernelInterface = object + OutOfResources = object + ASTSource = None + GPUTarget = None + +try: + autograd_profiler = torch.autograd.profiler +except AttributeError: # Compile workers only have a mock version of torch 
+ + class autograd_profiler: # type: ignore[no-redef] + _is_profiler_enabled = False + + +log = logging.getLogger(__name__) + + +def autotune_hints_to_configs( + hints: Set[AutotuneHint], size_hints, block_size: int +) -> List[Config]: + """ + AutotuneHints can be attached to the metadata of triton kernels for providing + suggestions about what to try for autotuning. One reason to do this is if there are + some configs that are only useful in specific scenarios, in which case we can avoid + wasting compile time on autotuning unless we know we are in one of those scenarios. + + Based on those hints, this function will generate a list of additional autotuning + configs to try. + """ + xyz_options: Tuple[Tuple[int, Optional[int], Optional[int]], ...] + configs = [] + + for hint in hints: + if hint == AutotuneHint.ELEMENTS_PER_WARP_32: + if len(size_hints) == 1: + xyz_options = ((block_size // 4, None, None),) + elif len(size_hints) == 2: + xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None)) + elif len(size_hints) == 3: + xyz_options = ( + (block_size // 4, 1, 1), + (1, block_size // 4, 1), + (1, 1, block_size // 4), + ) + for xyz in xyz_options: + configs.append( + triton_config( + size_hints, + *xyz, + num_elements_per_warp=32, + ) + ) + + return configs + + +def disable_pointwise_autotuning(inductor_meta): + # Autotuning can give different benchmarking results from run to run, and + # therefore we disable autotuning when use_deterministic flag is on. + if inductor_meta.get("are_deterministic_algorithms_enabled"): + return True + return not inductor_meta.get("autotune_pointwise", True) + + +def _dump_launch_params(args, kwargs, launcher, kernel_name): + call_args = [] + call_kwargs = {} + for arg in args: + if isinstance(arg, (int, bool)): + call_args.append(str(arg)) + else: + call_args.append("T") + for k, v in kwargs.items(): + if isinstance(arg, (int, bool)): + call_kwargs[k] = v + else: + call_kwargs[k] = v + for k, v in launcher.config.kwargs.items(): + call_kwargs[k] = v + call_kwargs["num_warps"] = launcher.config.num_warps + call_kwargs["num_stages"] = launcher.config.num_stages + args_str = "" + args_str += ", ".join(call_args) + for k, v in call_kwargs.items(): + args_str += f", {k}={v}" + + abs_path = os.path.abspath(sys.argv[0]) + with open(f"{abs_path}.launch_params", "a") as f: + f.write(f"{kernel_name} | {args_str}\n") + + +class CachingAutotuner(KernelInterface): + """ + Simplified version of Triton autotuner that has no invalidation + key and caches the best config to disk to improve cold start times. + Unlike the main triton Autotuner, this version can precompile all + configs, and does not rely on the Triton JIT. 
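# ---------------------------------------------------------------------------
# A minimal pure-Python sketch of how the ELEMENTS_PER_WARP_32 hint above
# expands into extra autotuning candidates: one candidate per tile dimension,
# with the hinted block size shrunk 4x in that dimension and (in the real
# function) num_elements_per_warp=32, i.e. one element per lane of a warp.
# The mirror below only reproduces the (XBLOCK, YBLOCK, ZBLOCK) tuples; the
# real code wraps each tuple in a triton Config via triton_config().
def elements_per_warp_32_options(size_hints, block_size):
    if len(size_hints) == 1:
        return [(block_size // 4, None, None)]
    if len(size_hints) == 2:
        return [(block_size // 4, 1, None), (1, block_size // 4, None)]
    return [(block_size // 4, 1, 1), (1, block_size // 4, 1), (1, 1, block_size // 4)]

assert elements_per_warp_32_options([4096, 4096], block_size=1024) == [
    (256, 1, None),
    (1, 256, None),
]
# ---------------------------------------------------------------------------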
+ """ + + def __init__( + self, + fn, + triton_meta, # passed directly to triton + configs, + save_cache_hook, + mutated_arg_names: List[str], # see [Note: clone mutated buffers] + heuristic_type, + size_hints=None, + inductor_meta=None, # metadata not relevant to triton + custom_kernel=False, # whether the kernel is inductor-generated or custom + filename: Optional[str] = None, + ): + super().__init__() + + assert len(configs) > 0, "Non-empty TritonConfig list required for compiling" + # makes sure there are no pre-hooks on any of the triton configs + for cfg in configs: + validate_triton_config(cfg) + + self.fn = fn + self.device_props: DeviceProperties = triton_meta["device"] + self.triton_meta = { + **triton_meta, + "device": self.device_props.index, + "device_type": self.device_props.type, + } + self.inductor_meta = {} if inductor_meta is None else inductor_meta + self.save_cache_hook = save_cache_hook + self.mutated_arg_names = mutated_arg_names + self.configs = configs + self.heuristic_type = heuristic_type + self.custom_kernel = custom_kernel + self.cuda_kernel_saved = False + if log.isEnabledFor(logging.DEBUG): + log.debug( + "CachingAutotuner gets %d configs for %s", + len(self.configs), + self.fn.__name__, + ) + for c in self.configs: + log.debug(c) + + self.launchers = [] # type: ignore[var-annotated] + self.lock = threading.Lock() + if os.getenv("TRITON_CACHE_DIR") is None: + os.environ["TRITON_CACHE_DIR"] = os.path.join( + cache_dir(), + "triton", + str(self.triton_meta.get("device", 0)), + ) + log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"]) + + self.size_hints = size_hints + self.coordesc_tuner = CoordescTuner( + is_mm=False, + name=self.fn.__name__, + size_hints=size_hints, + inductor_meta=self.inductor_meta, + ) + self.filename = filename + + self.precompile_time_taken_ns = 0 + self.autotune_time_taken_ns = 0 + + def precompile(self, warm_cache_only=False): + with self.lock: + if self.launchers: + return + self.launchers = [] + compiled_binaries = [] + if not self.configs: + raise RuntimeError("No triton configs are available") + for c in self.configs: + try: + compiled_binary, launcher = self._precompile_config( + c, warm_cache_only + ) + except OutOfResources as e: + if len(self.configs) == 1: + # There are no valid Triton configs + raise e + # Skip the config if we run out of resource + continue + self.launchers.append(launcher) + compiled_binaries.append(compiled_binary) + + if len(self.launchers) == 0: + raise RuntimeError( + "No valid triton configs. Report a fatal compilation error" + ) + + seen_configs = set(self.configs) + + device_prop = self.device_props + if ( + self.inductor_meta.get("dynamic_scale_rblock", True) + and self.heuristic_type == HeuristicType.REDUCTION + and self.size_hints is not None + # Disable for AMDGPU/Intel as Triton is not ready to return n_regs for a compiled_binary. 
+ and device_prop.type == "cuda" + and device_prop.major + and device_prop.major >= 8 + ): + assert device_prop.regs_per_multiprocessor + assert device_prop.max_threads_per_multi_processor + assert device_prop.multi_processor_count + for triton_config, compiled_binary in zip( + self.configs, compiled_binaries + ): + assert len(self.size_hints) == 2 + xblock = triton_config.kwargs.get("XBLOCK", 1) + rblock = triton_config.kwargs["RBLOCK"] + total_block = (self.size_hints[0] + xblock - 1) // xblock + nreg = getattr(compiled_binary, "n_regs", None) + if nreg is None: + continue + + # make sure rblock is not too small + if rblock <= 64: + continue + + # each SM of A100 has 65536 32-bit registers. To maximize + # the theoretical occupancy, we need run 2048 threads on each + # SM. So each thread should use no more than 65536 / 2048 + # = 32 registers. In cases where occupancy matters, and each + # thread uses too many registers, reduce RBLOCK to reduce + # the register usage. + # For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd + # from PLBartForCausalLM, latency improve from + # 7.795ms to 4.883ms. + # + if ( + nreg + <= device_prop.regs_per_multiprocessor + // device_prop.max_threads_per_multi_processor + ): + continue + + nreg_per_warp = nreg * 32 + nreg_per_block = nreg_per_warp * triton_config.num_warps + + # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processo / (32 * num_warps)' + # The formula below is a tighter upper bound since we have the assumption that + # nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor + # due to the if condition above and: + # regs_per_multiprocessor / nreg_per_block + # = regs_per_multiprocessor / (nreg * 32 * num_warps) + # < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps) + # = max_threads_per_multi_processor / (32 * num_warps) + # Using a tigher upper bound can reveal more optimization opportunities. 
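# ---------------------------------------------------------------------------
# A minimal worked example of the occupancy arithmetic used above, with the
# A100-like numbers from the comment (65536 registers and 2048 resident
# threads per SM); the kernel stats (n_regs=40, num_warps=8) are made up for
# illustration.  A kernel is register-bound once it needs more than
# 65536 / 2048 = 32 registers per thread, and the per-SM block budget is then
# regs_per_multiprocessor // (n_regs * 32 * num_warps).
regs_per_multiprocessor = 65536
max_threads_per_multi_processor = 2048
multi_processor_count = 108                  # A100 has 108 SMs

n_regs, num_warps = 40, 8                    # hypothetical compiled kernel
assert n_regs > regs_per_multiprocessor // max_threads_per_multi_processor

nreg_per_block = n_regs * 32 * num_warps                   # 32 threads per warp
max_blocks_per_sm = max(regs_per_multiprocessor // nreg_per_block, 1)
device_capacity = max_blocks_per_sm * multi_processor_count
print(max_blocks_per_sm, device_capacity)                  # 6 blocks/SM, 648 blocks total
# when total_block exceeds device_capacity, precompile() adds a copy of the
# config with RBLOCK halved to win back occupancy
# ---------------------------------------------------------------------------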
+ max_blocks_per_sm = max( + device_prop.regs_per_multiprocessor // nreg_per_block, 1 + ) + + if ( + total_block + <= max_blocks_per_sm * device_prop.multi_processor_count + ): + # no need to improve occupancy + continue + new_config = copy.deepcopy(triton_config) + new_config.kwargs["RBLOCK"] = rblock // 2 + if new_config in seen_configs: + continue + seen_configs.add(new_config) + log.debug( + "Dynamically scale down RBLOCK from TritonConfig(%s) and get a new TritonConfig(%s)", + triton_config, + new_config, + ) + self.launchers.append( + self._precompile_config(new_config, warm_cache_only)[1] + ) + self.configs = None + + def get_device_interface(self): + # this code cannot run in compile workers, because it imports from torch + from torch._dynamo.device_interface import get_interface_for_device + + return get_interface_for_device(self.device_props.type.replace("hip", "cuda")) + + def _precompile_config(self, cfg: Config, warm_cache_only: bool): + """Ahead of time compile a given autotuner config.""" + compile_meta = copy.deepcopy(self.triton_meta) + for k, v in cfg.kwargs.items(): + if self.device_props.type == "hip": + if k == "matrix_instr_nonkdim": + compile_meta["matrix_instr_nonkdim"] = v + continue + if k == "waves_per_eu": + compile_meta["waves_per_eu"] = v + continue + compile_meta["constants"][self.fn.arg_names.index(k)] = v + compile_meta["num_warps"] = cfg.num_warps + compile_meta["num_stages"] = cfg.num_stages + compile_meta["debug"] = self.inductor_meta.get( + "assert_indirect_indexing", True + ) and not self.inductor_meta.get("is_hip", False) + + # device type will be "hip" rather than "cuda" here + compile_meta["device_type"] = self.device_props.type + compile_meta["cc"] = self.device_props.cc + + if ASTSource: + compile_args = ( + ASTSource( + self.fn, + compile_meta["signature"], + compile_meta["constants"], + compile_meta["configs"][0], + ), + ) + + cc_str = str(compile_meta["cc"]) + if "gfx10" in cc_str or "gfx11" in cc_str: + rocm_warp_size = 32 + else: + rocm_warp_size = 64 + + if GPUTarget: + target = GPUTarget( + compile_meta["device_type"], + compile_meta["cc"], + rocm_warp_size if torch.version.hip else 32, + ) + else: + target = ( + (compile_meta["device_type"], compile_meta["cc"]) + if not torch.version.hip + else [ + compile_meta["device_type"], + compile_meta["cc"], + rocm_warp_size, + ] + ) + + options = { + "num_warps": compile_meta["num_warps"], + "num_stages": compile_meta["num_stages"], + "debug": compile_meta["debug"], + } + if self.device_props.type == "hip": + if "waves_per_eu" in compile_meta: + options["waves_per_eu"] = compile_meta["waves_per_eu"] + if "matrix_instr_nonkdim" in compile_meta: + options["matrix_instr_nonkdim"] = compile_meta[ + "matrix_instr_nonkdim" + ] + compile_kwargs = { + "target": target, + "options": options, + } + else: + compile_args = (self.fn,) + compile_kwargs = compile_meta + + if warm_cache_only: + return ( + triton.compile(*compile_args, **compile_kwargs), + None, + ) + + # importing from torch is safe now that precompile has returned + from torch._dynamo.device_interface import DeviceGuard + + device_interface = self.get_device_interface() + + # load binary to the correct device + with DeviceGuard(device_interface, compile_meta["device"]): # type: ignore[attr-defined] + # need to initialize context + device_interface.synchronize(device_interface.current_device()) + + try: + binary = triton.compile(*compile_args, **compile_kwargs) + except Exception: + log.exception( + "Triton compilation failed: %s\n%s\nmetadata: 
%s", + self.inductor_meta.get("kernel_name", "triton_"), + self.fn.src, + compile_meta, + ) + raise + binary._init_handles() + + call_args = [ + arg + for i, arg in enumerate(self.fn.arg_names) + if i not in self.fn.constexprs + ] + def_args = [name for name in self.fn.arg_names if name not in cfg.kwargs] + + binary_shared = ( + binary.shared if hasattr(binary, "shared") else binary.metadata.shared + ) + + scope = { + "grid_meta": cfg.kwargs, + "bin": binary, + "launch_enter_hook": CompiledKernel.launch_enter_hook, + "launch_exit_hook": CompiledKernel.launch_exit_hook, + "metadata": binary.packed_metadata + if hasattr(binary, "packed_metadata") + else binary.metadata, + "shared": binary_shared, + } + + scope["num_warps"] = ( + binary.num_warps + if hasattr(binary, "num_warps") + else binary.metadata.num_warps + ) + + scope["cta_args"] = ( + (binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims")) + if hasattr(binary, "num_ctas") + else ( + (binary.metadata.num_ctas, *binary.metadata.cluster_dims) + if hasattr(binary, "metadata") + else () + ) + ) + + scope["function"] = get_first_attr(binary, "function", "cu_function") + + def get_launch_args_without_kernel_launch_metadata( + grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args before CompiledKernel.launch_metadata is added. + """ + return ( + grid_0, + grid_1, + grid_2, + num_warps, + *cta_args, + shared, + stream, + function, + launch_enter_hook, + launch_exit_hook, + metadata, + ) + + # Getting the kernel launch args is extremely perf-sensitive. Evaluating + # `bin.launch_metadata` is relatively expensive, and returns None unless a + # `launch_enter_hook` is installed. So if we don't have that hook installed, + # we want to burn None in to the launch args with zero overhead. + # See https://github.com/pytorch/pytorch/issues/123597 + if binary.launch_enter_hook: + + def get_launch_args_with_kernel_launch_metadata( + grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args after CompiledKernel.launch_metadata is added + by https://github.com/openai/triton/pull/3492 . + """ + return ( + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + bin.launch_metadata(grid, stream, *args), + launch_enter_hook, + launch_exit_hook, + ) + + else: + + def get_launch_args_with_kernel_launch_metadata( + grid, + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + bin, + launch_enter_hook, + launch_exit_hook, + num_warps, + shared, + cta_args, + args, + ): + """ + Construct launch args after CompiledKernel.launch_metadata is added + by https://github.com/openai/triton/pull/3492 . 
+ """ + return ( + grid_0, + grid_1, + grid_2, + stream, + function, + metadata, + None, + launch_enter_hook, + launch_exit_hook, + ) + + scope["get_launch_args"] = ( + get_launch_args_with_kernel_launch_metadata + if hasattr(binary, "launch_metadata") + else get_launch_args_without_kernel_launch_metadata + ) + + scope["runner"] = get_first_attr(binary, "run", "c_wrapper") + + exec( + f""" + def launcher({', '.join(def_args)}, grid, stream): + if callable(grid): + grid_0, grid_1, grid_2 = grid(grid_meta) + else: + grid_0, grid_1, grid_2 = grid + + args = {', '.join(call_args)}, + launch_args = get_launch_args( + grid, grid_0, grid_1, grid_2, stream, function, + metadata, bin, launch_enter_hook, launch_exit_hook, + num_warps, shared, cta_args, args + ) + runner(*launch_args, *args) + return bin + """.lstrip(), + scope, + ) + + launcher = scope["launcher"] + launcher.config = cfg + launcher.n_regs = getattr(binary, "n_regs", None) + launcher.n_spills = getattr(binary, "n_spills", None) + launcher.shared = binary_shared + launcher.store_cubin = self.inductor_meta.get("store_cubin", False) + # store this global variable to avoid the high overhead of reading it when calling run + if launcher.store_cubin: + launcher.fn = self.fn + launcher.bin = binary + + return binary, launcher + + def bench(self, launcher, *args, grid, with_profiler=False, **kwargs): + """Measure the performance of a given launcher""" + # we don't skip configs wiht spilled registers when auto-tuning custom + # (user-written) Triton kernels, as (i) we don't have any knowledge or + # control over the kernel code; (ii) there is empirical evidence that + # for some (complicated) custom Triton kernels, a register-spilling + # config may yield the best latency. + if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get( + "spill_threshold", 16 + ): + log.debug( + "Skip config %s because of register spilling: %d", + launcher.config, + launcher.n_spills, + ) + return float("inf") + + device_interface = self.get_device_interface() + stream = device_interface.get_raw_stream(device_interface.current_device()) + + def kernel_call(): + cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) + launcher( + *cloned_args, + **cloned_kwargs, + grid=grid, + stream=stream, + ) + + if with_profiler: + from torch._inductor.utils import do_bench_using_profiling + + return do_bench_using_profiling(kernel_call, warmup=10, rep=40) + + return benchmarker.benchmark_gpu(kernel_call, rep=40, fast_flush=True) + + def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]: + from ..compile_fx import clone_preserve_strides + + # [Note: clone mutated buffers] + # clone inplace buffers to avoid autotune contaminating them if + # the kernel does in-place stores. 
avoid cloning other buffers because + # it leads to increase memory use + cloned_args = [] + for i, arg in enumerate(args): + if self.fn.arg_names[i] in self.mutated_arg_names: + assert isinstance(arg, torch.Tensor) + cloned_args.append(clone_preserve_strides(arg)) + else: + cloned_args.append(arg) + + cloned_kwargs: Dict[str, Any] = {} + for name, arg in kwargs.items(): + if name in self.mutated_arg_names: + assert isinstance(arg, torch.Tensor) + cloned_kwargs[name] = clone_preserve_strides(arg) + else: + cloned_kwargs[name] = arg + + return cloned_args, cloned_kwargs + + def benchmark_all_configs(self, *args, **kwargs): + with dynamo_timed("CachingAutotuner.benchmark_all_configs"): + timings = { + launcher: self.bench(launcher, *args, **kwargs) + for launcher in self.launchers + } + + for k, v in timings.items(): + self.coordesc_tuner.cache_benchmark_result(k.config, v) + + if log.isEnabledFor(logging.DEBUG): + log.debug("Benchmark all input configs for %s, get:", self.fn.__name__) + for k, v in timings.items(): + log.debug( + "%s: %f, nreg %d, nspill %d, #shared-mem %s", + k.config, + v, + k.n_regs, + k.n_spills, + k.shared, + ) + + return timings + + def autotune_to_one_config(self, *args, **kwargs): + """Do the actual autotuning""" + start_time = time.time_ns() + timings = self.benchmark_all_configs(*args, **kwargs) + benchmark_time_taken_ns = time.time_ns() - start_time + self.launchers = [builtins.min(timings, key=timings.get)] + self.autotune_time_taken_ns = ( + self.precompile_time_taken_ns + benchmark_time_taken_ns + ) + if self.save_cache_hook: + self.save_cache_hook(self.launchers[0].config, self.autotune_time_taken_ns) + + def save_gpu_kernel(self, grid, stream, launcher): + if callable(grid): + grid_x, grid_y, grid_z = grid(launcher.config.kwargs) + else: + grid_x, grid_y, grid_z = grid + + key = self.inductor_meta.get("kernel_name", None) # unique kernel name + assert key is not None, "kernel_name can not be None" + params = { + "mangled_name": launcher.bin.metadata.name + if hasattr(launcher.bin.metadata, "name") + else launcher.bin.metadata["name"], + "grid_x": grid_x, + "grid_y": grid_y, + "grid_z": grid_z, + "x_block": launcher.config.kwargs.get("XBLOCK", 1), + "y_block": launcher.config.kwargs.get("YBLOCK", None), + "z_block": launcher.config.kwargs.get("ZBLOCK", None), + "num_warps": launcher.bin.num_warps + if hasattr(launcher.bin, "num_warps") + else launcher.bin.metadata.num_warps, + "shared_mem": launcher.bin.shared + if hasattr(launcher.bin, "shared") + else launcher.bin.metadata.shared, + "stream": stream, + # User defined triton kernels will have arbitrary kwarg names + "meta": launcher.config.kwargs, + } + from torch._inductor.codecache import CudaKernelParamCache + + bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin") + binary = launcher.bin.asm[bin_type] + CudaKernelParamCache.set(key, params, binary, bin_type) + + self.cuda_kernel_saved = True + + def coordinate_descent_tuning(self, launcher, *args, **kwargs): + """ + Coordinate descent tuning can be run with or without max-autotune. + + The only difference between these two is the starting config for coordinate_descent tuning. + E.g., assuming regular autotune only get one config C1; while max-autotune get 4 configs C1, C2, C3, C4 + and max-autotune figure out C3 is the best. + + Then if coordinate desecnt tuning is run with max-autotune disabled, it will start from C1; + while if coordinate descent tuning is run with max-autotune enabled, it will start from C3. 
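# ---------------------------------------------------------------------------
# A minimal sketch of the coordinate-descent idea described above (a toy, not
# the real CoordescTuner API): start from the best known config, tweak one
# knob at a time, and keep a change only if the measured time improves.  The
# cost function below is made up and stands in for self.bench().
def coordinate_descent(start_config, candidates_per_knob, cost_fn):
    best, best_cost = dict(start_config), cost_fn(start_config)
    improved = True
    while improved:
        improved = False
        for knob, candidates in candidates_per_knob.items():
            for value in candidates:
                trial = {**best, knob: value}
                cost = cost_fn(trial)
                if cost < best_cost:
                    best, best_cost, improved = trial, cost, True
    return best, best_cost

fake_latency = lambda cfg: abs(cfg["XBLOCK"] - 64) + abs(cfg["num_warps"] - 4)
best, _ = coordinate_descent(
    {"XBLOCK": 1, "num_warps": 8},
    {"XBLOCK": [1, 2, 4, 8, 16, 32, 64, 128], "num_warps": [1, 2, 4, 8, 16]},
    fake_latency,
)
assert best == {"XBLOCK": 64, "num_warps": 4}
# ---------------------------------------------------------------------------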
+ """ + if ( + self.heuristic_type == HeuristicType.TEMPLATE + or self.heuristic_type == HeuristicType.USER_AUTOTUNE + ): + # skip triton template + return launcher + + config2launcher = {launcher.config: launcher} + + def benchmark_one_config(config): + with self.lock: + _, launcher = self._precompile_config(config, False) + config2launcher[config] = launcher + + out = self.bench(launcher, *args, **kwargs) + log.debug( + "COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d", + launcher.config, + out, + launcher.n_regs, + launcher.n_spills, + launcher.shared, + ) + return out + + assert not ( + self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION + and "RBLOCK" in launcher.config.kwargs + ), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have RBLOCK" + start_time = time.time_ns() + best_config = self.coordesc_tuner.autotune( + benchmark_one_config, launcher.config, None + ) + coordesc_time_taken_ns = time.time_ns() - start_time + best_config.found_by_coordesc = True + + if self.save_cache_hook: + self.save_cache_hook( + best_config, + self.autotune_time_taken_ns + coordesc_time_taken_ns, + found_by_coordesc=True, + ) + return config2launcher.get(best_config) + + def run(self, *args, grid, stream, **kwargs): + if len(self.launchers) != 1: + if len(self.launchers) == 0: + start_time = time.time_ns() + self.precompile() + self.precompile_time_taken_ns = time.time_ns() - start_time + if len(self.launchers) > 1: + self.autotune_to_one_config(*args, grid=grid, **kwargs) + + if not getattr( + self.launchers[0].config, "found_by_coordesc", False + ) and self.inductor_meta.get("coordinate_descent_tuning", False): + self.launchers = [ + self.coordinate_descent_tuning( + self.launchers[0], *args, grid=grid, **kwargs + ) + ] + + (launcher,) = self.launchers + if launcher.store_cubin: + self.save_gpu_kernel(grid, stream, launcher) + + if os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", 0) == "1": + _dump_launch_params(args, kwargs, launcher, self.fn.__name__) + + # it is faster than entering and exiting a context manager, even if the context + # manager is a nullcontext. + if autograd_profiler._is_profiler_enabled: + # grid can be a tuple of ints or a string. 
+ if isinstance(grid, tuple): + grid_info = str(grid) + else: + grid_info = getattr(grid, "grid_fn_str", "") + with torch._C._profiler._RecordFunctionFast( + self.inductor_meta.get("kernel_name", "triton kernel"), + args, + { + "kernel_file": "" if self.filename is None else self.filename, + "kernel_backend": "triton", + "grid": grid_info, + "stream": stream, + }, + ): + return launcher( + *args, + **kwargs, + grid=grid, + stream=stream, + ) + else: + return launcher( + *args, + **kwargs, + grid=grid, + stream=stream, + ) + + +def _find_names(obj): + import gc + import inspect + + frame = inspect.currentframe() + while frame is not None: + frame.f_locals + frame = frame.f_back + obj_names = [] + for referrer in gc.get_referrers(obj): + if isinstance(referrer, dict): + for k, v in referrer.items(): + if v is obj: + obj_names.append(k) + return obj_names + + +collected_calls: List[Any] = [] + + +def start_graph(): + collected_calls.clear() + + +def end_graph(output_file): + if len(collected_calls) == 0: + return + overall_time = sum(call[0] for call in collected_calls) + overall_gb = sum(call[1] for call in collected_calls) + cur_file = inspect.stack()[1].filename + summary_str = ( + f"SUMMARY ({cur_file})\n" + f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s" + ) + print(summary_str) + print() + if output_file is not None: + # sort perf numbers in descending order, i.e. placing the + # most runtime-heavy kernels at the top of the list + sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True) + try: + with open(output_file, "a") as file: + log.debug("Save profile bandwidth results to %s", output_file) + file.write("====================\n") + file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n") + for ms, num_gb, gb_per_s, kernel_name in sorted_calls: + # also display the runtime percentage for each kernel + percentage = f"{ms/overall_time*100:.2f}%" + suffix = f" \t {percentage} \t {kernel_name}" + bw_info_str = create_bandwidth_info_str( + ms, + num_gb, + gb_per_s, + suffix=suffix, + color=False, + ) + file.write(bw_info_str + "\n") + file.write(f"{summary_str}\n\n") + except Exception as e: + log.warning( + "failed to write profile bandwidth result into %s: %s", + output_file, + e, + ) + + +class DebugAutotuner(CachingAutotuner): + def __init__(self, *args, regex_filter="", with_profiler=False, **kwargs): + self.regex_filter = regex_filter + self.with_profiler = with_profiler + super().__init__(*args, **kwargs) + self.cached = None + + def run(self, *args, grid, stream): + possible_names = _find_names(self) + kernel_name = f"{max(possible_names, key=len)}" + if not re.match(self.regex_filter, kernel_name): + return + super().run(*args, grid=grid, stream=stream) + (launcher,) = self.launchers + + if self.cached is None: + ms = self.bench( + launcher, *args, grid=grid, with_profiler=self.with_profiler + ) + num_in_out_ptrs = len( + [ + arg_name + for arg_name in self.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = self.inductor_meta.get("kernel_num_gb", None) + if num_gb is None: + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + gb_per_s = num_gb / (ms / 1e3) + self.cached = ms, num_gb, gb_per_s, kernel_name + collected_calls.append((ms, num_gb, gb_per_s, kernel_name)) + print( + create_bandwidth_info_str( + ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}" + ) + ) + + +def hash_configs(configs: List[Config]): + """ + Hash used to check for changes in configurations + """ + 
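# ---------------------------------------------------------------------------
# A minimal worked example of the bandwidth numbers reported by DebugAutotuner
# above and formatted by create_bandwidth_info_str: gigabytes moved divided by
# runtime.  The kernel shape is made up: a pointwise op that reads and writes
# 64M float32 elements in 0.9 ms.
elements = 64 * 1024 * 1024
bytes_moved = elements * 4 * 2              # one 4-byte read plus one 4-byte write
ms = 0.9

num_gb = bytes_moved / 1e9
gb_per_s = num_gb / (ms / 1e3)
print(f"{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s")
# -> 0.900ms   0.537 GB    596.52GB/s, which the >0.012ms / <650GB/s rule in
#    create_bandwidth_info_str would flag as slow (printed in red)
# ---------------------------------------------------------------------------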
hasher = hashlib.sha256() + for cfg in configs: + hasher.update( + f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode() + ) + return hasher.hexdigest() + + +def cached_autotune( + size_hints: Optional[List[int]], + configs: List[Config], + triton_meta, + heuristic_type, + filename=None, + inductor_meta=None, + custom_kernel=False, +): + """ + A copy of triton.autotune that calls our subclass. Our subclass + has additional debugging, error handling, and on-disk caching. + """ + configs = unique_configs(configs) + assert len(configs) == 1 or filename + inductor_meta = {} if inductor_meta is None else inductor_meta + + disabled = inductor_meta.get("force_disable_caches", False) + + # on disk caching logic and/or remote caching + autotune_cache = None + if ( + not disabled + and filename is not None + and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning")) + ): + configs_hash = hash_configs(configs) + + autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash) + if autotune_cache: + if best_config := autotune_cache.read_best(inductor_meta, configs): + configs = [best_config] + + else: + if disabled: + log.debug("autotune caching is disabled by config.force_disable_caches") + + mutated_arg_names = inductor_meta.pop("mutated_arg_names", ()) + + def decorator(fn): + # Remove XBLOCK from config if it's not a function argument. + # This way, coordinate descent tuning will not try to tune it. + # + # Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1. + import inspect + + if "XBLOCK" not in inspect.signature(fn.fn).parameters: + for tconfig in configs: + if "XBLOCK" in tconfig.kwargs: + assert tconfig.kwargs["XBLOCK"] == 1 + tconfig.kwargs.pop("XBLOCK") + + if inductor_meta.get("profile_bandwidth"): + return DebugAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + regex_filter=inductor_meta["profile_bandwidth_regex"], + with_profiler=inductor_meta[ + "profile_bandwidth_with_do_bench_using_profiling" + ], + configs=configs, + save_cache_hook=autotune_cache and autotune_cache.save, + mutated_arg_names=mutated_arg_names, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + filename=filename, + ) + return CachingAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + configs=configs, + save_cache_hook=autotune_cache and autotune_cache.save, + mutated_arg_names=mutated_arg_names, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + filename=filename, + ) + + return decorator + + +def unique_configs(configs: List[Config]): + """Remove duplicate configurations""" + seen = set() + pruned_configs = [] + + for cfg in configs: + key = triton_config_to_hashable(cfg) + if key not in seen: + seen.add(key) + pruned_configs.append(cfg) + return pruned_configs + + +def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None): + for numel, label in zip((xnumel, ynumel, znumel), "XYZ"): + if numel is None: + continue + block = cfg[f"{label}BLOCK"] + if numel == 1: + assert block == 1, ( + f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1" + f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})." + ) + max_block = TRITON_MAX_BLOCK[label] + max_block_str = f'config.triton.max_block["{label}"]' + assert max_block % block == 0, ( + f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}" + f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})." 
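# ---------------------------------------------------------------------------
# A minimal pure-Python sketch of the hashing just defined: sorted kwargs plus
# num_warps/num_stages are fed into sha256 so the on-disk autotune cache can
# be invalidated when the candidate configs change.  The mirror below uses
# plain tuples instead of triton Configs and shows that kwarg ordering does
# not change the digest while a different block size does.
import hashlib

def hash_plain_configs(configs):
    hasher = hashlib.sha256()
    for kwargs, num_warps, num_stages in configs:
        hasher.update(f"{sorted(kwargs.items())} {num_warps} {num_stages}\n".encode())
    return hasher.hexdigest()

a = hash_plain_configs([({"XBLOCK": 64, "RBLOCK": 8}, 4, 1)])
b = hash_plain_configs([({"RBLOCK": 8, "XBLOCK": 64}, 4, 1)])
c = hash_plain_configs([({"XBLOCK": 128, "RBLOCK": 8}, 4, 1)])
assert a == b and a != c
# ---------------------------------------------------------------------------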
+ ) + + +def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False): + # On AMD GPU each warp has 64 lanes which is double the size on NV GPU, + # therefore using half the number of warps here correspondingly. + if torch.version.hip: + max_num_warps = (max_num_warps + 1) // 2 + min_num_warps = (min_num_warps + 1) // 2 + # persistent reduction is register intensive + if register_intensive: + max_num_warps = max_num_warps // 2 + return next_power_of_2(min(max(num_warps, min_num_warps), max_num_warps)) + + +def _check_max_grid_x(size_hints, x, num_warps): + # Check if maxGridSize is exceeded - if so then must scale XBLOCK further + max_grid_x = 2147483647 + warp_size = ( + 64 if torch.version.hip else 32 + ) # TODO: query warp size once #129663 is merged + num_blocks = (size_hints[0] + x - 1) // x + + while (num_blocks * num_warps * warp_size) > max_grid_x and x < size_hints[0]: + x *= 2 # Scale up XBLOCK if grid exceeds limits + num_blocks = num_blocks // 2 + if x >= max_grid_x: + raise AssertionError( + "Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue" + ) + return x, num_blocks + + +def triton_config( + size_hints, + x, + y=None, + z=None, + num_stages=1, + num_elements_per_warp=256, + min_elem_per_thread=0, +) -> Config: + """ + Construct a pointwise triton config with some adjustment heuristics + based on size_hints. Size_hints is a tuple of numels in each tile + dimension and will be rounded up to the nearest power of 2. + + num_elements_per_warp is a suggestion for controlling how many warps + the triton config should contain. e.g.: if x=16, y=8, z=4 then + num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128, + we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's + just a suggestion, and sometimes other adjustment heuristics will + override the num_elements_per_warp. + + min_elem_per_thread controls the minimum number of elements + processed by each thread. It's always enforced. + """ + # Ideally we want to read this from some device config + + # for a 2d size_hints [a, b], a should be mapped to YBLOCK rather than XBLOCK + size_hints = list(reversed(size_hints)) + + maxGridSize = [2147483647, 65535, 65535] + + target = conditional_product(x, y, z) + if conditional_product(*size_hints) < target: + target //= 8 + + # shrink sizes to size hints + x = min(x, size_hints[0]) + if y: + y = min(y, size_hints[1]) + if z: + z = min(z, size_hints[2]) + + # if we are below original block size, scale up where we can; + # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension + while x < min(size_hints[0], TRITON_MAX_BLOCK["X"]) and ( + x * maxGridSize[0] < size_hints[0] or conditional_product(x, y, z) < target + ): + x *= 2 + while ( + y + and y < min(size_hints[1], TRITON_MAX_BLOCK["Y"]) + and ( + y * maxGridSize[1] < size_hints[1] or conditional_product(x, y, z) < target + ) + ): + y *= 2 + while ( + z + and z < min(size_hints[2], TRITON_MAX_BLOCK["Z"]) + and ( + z * maxGridSize[2] < size_hints[2] or conditional_product(x, y, z) < target + ) + ): + z *= 2 + + num_warps = _num_warps( + conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1 + ) + # we are going to arrive at 2 warps only if bs was too small due to + # numel being too small. 
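# ---------------------------------------------------------------------------
# A minimal pure-Python sketch of the warp-count rule used by triton_config
# above: num_elements // num_elements_per_warp, clamped by _num_warps to a
# power of two between min_num_warps and max_num_warps (both halved on ROCm,
# where a warp has 64 lanes).  The register_intensive halving is omitted here.
# Worked example from the docstring: a 16 x 8 x 4 tile at 128 elements/warp.
def clamp_num_warps(num_warps, max_num_warps=8, min_num_warps=2, hip=False):
    if hip:
        max_num_warps = (max_num_warps + 1) // 2
        min_num_warps = (min_num_warps + 1) // 2
    n = min(max(num_warps, min_num_warps), max_num_warps)
    return 1 << max(n - 1, 0).bit_length()          # round up to a power of two

x, y, z = 16, 8, 4
num_elements = x * y * z                            # 512 elements per program
assert clamp_num_warps(num_elements // 128, min_num_warps=1) == 4
assert clamp_num_warps(num_elements // 128, min_num_warps=1, hip=True) == 4
assert clamp_num_warps(1024 // 256) == 4            # e.g. a 1-D block of 1024 at 256 elems/warp
# ---------------------------------------------------------------------------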
However to workaround some ptx bugs we still + # want at least 4 warps if there's enough elements per thread + # given that this is a rare situation, don't expect this to affect perf + # in general + # see https://github.com/pytorch/pytorch/pull/97950 + if conditional_product(x, y, z) >= 128 and not torch.version.hip: + num_warps = max(num_warps, 4) + xnumel = size_hints[0] + ynumel = size_hints[1] if y else None + znumel = size_hints[2] if z else None + + # Increase x to satisfy min_elem_per_thread requirements. + block_size = max( + conditional_product(x, y, z), + min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps, + ) + x *= math.ceil(block_size / conditional_product(x, y, z)) + + x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps) + + cfg = {"XBLOCK": x} + if y: + cfg["YBLOCK"] = y + if z: + cfg["ZBLOCK"] = z + assert x <= TRITON_MAX_BLOCK["X"], f"increase TRITON_MAX_BLOCK['X'] to {x}" + check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel) + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def triton_config_reduction( + size_hints, x, r, num_stages=1, num_warps=None, register_intensive=False +) -> Config: + """ + Construct a reduction triton config with some adjustment heuristics + based on size_hints. Size_hints is a tuple of numels in each tile + dimension and will be rounded up to the nearest power of 2. + """ + + target = conditional_product(x, r) + if conditional_product(*size_hints) < target: + target //= 8 + + # shrink sizes to size hints + x = min(x, size_hints[0]) + r = min(r, size_hints[1]) + + # if we are below original block size, scale up where we can + while x < size_hints[0] and conditional_product(x, r) < target: + x *= 2 + while r < size_hints[1] and conditional_product(x, r) < target: + r *= 2 + + if num_warps is None: + num_warps = conditional_product(x, r) // 128 + num_warps = _num_warps( + num_warps, max_num_warps=16, register_intensive=register_intensive + ) + + x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps) + + while conditional_product(x, r) > target: + if r == 1: + break + r = r // 2 + + cfg = {"XBLOCK": x, "RBLOCK": r} + check_config(cfg, xnumel=size_hints[0]) + assert x <= TRITON_MAX_BLOCK["X"], f"increase TRITON_MAX_BLOCK['X'] to {x}" + assert r <= TRITON_MAX_BLOCK["R"], f"increase TRITON_MAX_BLOCK['r'] to {r}" + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1): + """ + Construct a tile reduction triton config with some adjustment + heuristics based on size_hints. Size_hints is a tuple of numels in + each tile dimension and will be rounded up to the nearest power of 2. 
+ """ + + target = conditional_product(x, y, r) + if conditional_product(*size_hints) < target: + target //= 8 + + # shrink sizes to size hints + x = min(x, size_hints[0]) + y = min(y, size_hints[1]) + r = min(r, size_hints[2]) + + # if we are below original block size, scale up where we can + while x < size_hints[0] and conditional_product(x, y, r) < target: + x *= 2 + while r < size_hints[2] and conditional_product(x, y, r) < target: + r *= 2 + while y < size_hints[1] and conditional_product(x, y, r) < target: + y *= 2 + + cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r} + num_warps = _num_warps(conditional_product(x, y, r) // 256, min_num_warps=1) + check_config(cfg, xnumel=size_hints[0], ynumel=size_hints[1]) + assert r <= TRITON_MAX_BLOCK["R"], f"increase TRITON_MAX_BLOCK['r'] to {r}" + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def pointwise( + size_hints, + triton_meta, + tile_hint=None, + filename=None, + min_elem_per_thread=0, + inductor_meta=None, +): + """ + Construct @triton.heuristics() based on size_hints. + """ + inductor_meta = {} if inductor_meta is None else inductor_meta + assert not inductor_meta.get("no_x_dim") + + numel = functools.reduce(operator.mul, size_hints) + bs = max(256, min(numel // 128, 1024)) + + hinted_configs = autotune_hints_to_configs( + inductor_meta.get("autotune_hints", set()), size_hints, bs + ) + + triton_config_with_settings = functools.partial( + triton_config, min_elem_per_thread=min_elem_per_thread + ) + + if len(size_hints) == 1: + if disable_pointwise_autotuning(inductor_meta) and not ( + inductor_meta.get("max_autotune") + or inductor_meta.get("max_autotune_pointwise") + ): + return cached_autotune( + size_hints, + [triton_config_with_settings(size_hints, bs)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + else: + return cached_autotune( + size_hints, + [ + triton_config_with_settings( + size_hints, bs, num_elements_per_warp=256 + ), + triton_config_with_settings( + size_hints, bs // 2, num_elements_per_warp=64 + ), + *hinted_configs, + ], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + if len(size_hints) == 2: + if ( + disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE + ) and not ( + inductor_meta.get("max_autotune") + or inductor_meta.get("max_autotune_pointwise") + ): + return cached_autotune( + size_hints, + [triton_config_with_settings(size_hints, 32, 32)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + return cached_autotune( + size_hints, + [ + triton_config_with_settings(size_hints, 32, 32), + triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 + triton_config_with_settings(size_hints, 256, 16), + triton_config_with_settings(size_hints, 16, 256), + triton_config_with_settings(size_hints, bs, 1), + triton_config_with_settings(size_hints, 1, bs), + *hinted_configs, + ], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.POINTWISE, + ) + if len(size_hints) == 3: + if disable_pointwise_autotuning(inductor_meta): + return cached_autotune( + size_hints, + [triton_config_with_settings(size_hints, 16, 16, 16)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + return cached_autotune( + size_hints, + [ + 
triton_config_with_settings(size_hints, 16, 16, 16), + triton_config_with_settings(size_hints, 64, 8, 8), + triton_config_with_settings(size_hints, 8, 64, 8), + triton_config_with_settings(size_hints, 8, 8, 64), + triton_config_with_settings(size_hints, bs, 1, 1), + triton_config_with_settings(size_hints, 1, bs, 1), + triton_config_with_settings(size_hints, 1, 1, bs), + *hinted_configs, + ], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.POINTWISE, + ) + raise NotImplementedError(f"size_hints: {size_hints}") + + +def _reduction_configs( + *, size_hints: List[int], inductor_meta: Dict[str, Any] +) -> List[Config]: + reduction_hint = inductor_meta.get("reduction_hint", None) + assert len(size_hints) == 2 + rnumel = size_hints[-1] + + register_intensive = False + MAX_RBLOCK = 2048 + if ( + size_hints[0] >= 1024 + and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0) + >= 10 + ): + # A heuristics to reduce RBLOCK if a kernel potentially need many registers. + # Consider load and reduction since load need move data into registers and + # reduction needs an accumulator. + # + # The magic numbers are a bit arbitrary. + # + # We cannot rely on dynamically scaling down RBLOCK later, since sometimes + # triton makes it to use less registers with worse perf. Check: + # https://github.com/pytorch/pytorch/issues/126463 + # + # The heuristic is a very simple one since registers can be reused. But + # hopefully it can be a good enough indicator. + MAX_RBLOCK = 1024 + register_intensive = True + + contiguous_config = triton_config_reduction( + size_hints, + 1, + (rnumel if 256 <= rnumel < MAX_RBLOCK else MAX_RBLOCK), + register_intensive=register_intensive, + ) + outer_config = triton_config_reduction( + size_hints, 64, 8, register_intensive=register_intensive + ) + tiny_config = triton_config_reduction( + size_hints, + 2 * (256 // rnumel) if rnumel <= 256 else 1, + min(rnumel, MAX_RBLOCK), + register_intensive=register_intensive, + ) + if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"): + pass # skip all these cases + elif reduction_hint == ReductionHint.INNER: + return [contiguous_config] + elif reduction_hint == ReductionHint.OUTER: + return [outer_config] + elif reduction_hint == ReductionHint.OUTER_TINY: + return [tiny_config] + if disable_pointwise_autotuning(inductor_meta): + return [triton_config_reduction(size_hints, 32, 128)] + return [ + contiguous_config, + outer_config, + tiny_config, + triton_config_reduction(size_hints, 64, 64), + triton_config_reduction(size_hints, 8, 512), + # halve the XBLOCK/RBLOCK compared to outer_config + # TODO: this may only be beneficial when each iteration of the reduction + # is quite heavy. E.g. 
https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72 + triton_config_reduction(size_hints, 64, 4, num_warps=8), + ] + + +def reduction( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + """args to @triton.heuristics()""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints = [1, *size_hints[1:]] + + assert triton_meta is not None + rnumel = size_hints[-1] + if len(size_hints) != 2: + raise NotImplementedError(f"size_hints: {size_hints}") + + configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta) + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.REDUCTION, + filename=filename, + ) + + +def persistent_reduction( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints = [1, *size_hints[1:]] + + xnumel, rnumel = size_hints + + configs = [ + triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) + for xblock in (1, 8, 32, 128) + if xblock == 1 or (rnumel * xblock <= 4096 and xblock <= xnumel) + ] + + # TODO(jansel): we should be able to improve these heuristics + if reduction_hint == ReductionHint.INNER and rnumel >= 256: + configs = configs[:1] + elif reduction_hint == ReductionHint.OUTER: + configs = configs[-1:] + elif reduction_hint == ReductionHint.OUTER_TINY: + configs = [ + triton_config_reduction( + size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel + ) + ] + for c in configs: + # we don't need RBLOCK for persistent reduction + c.kwargs.pop("RBLOCK") + + if disable_pointwise_autotuning(inductor_meta): + configs = configs[:1] + + return cached_autotune( + size_hints, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.PERSISTENT_REDUCTION, + ) + + +def split_scan( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + """Heuristic for TritonSplitScanKernel""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints = [1, *size_hints[1:]] + + assert triton_meta is not None + if len(size_hints) != 2: + raise NotImplementedError(f"size_hints: {size_hints}") + + configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta) + + # Fixup configs to enforce the minimum RBLOCK size + min_rblock = inductor_meta.get("min_split_scan_rblock", 256) + for cfg in configs: + if cfg.kwargs["RBLOCK"] < min_rblock: + cfg.kwargs["RBLOCK"] = min_rblock + + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.SPLIT_SCAN, + filename=filename, + ) + + +def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=None): + """ + Compile a triton template + """ + return cached_autotune( + None, + [triton.Config({}, num_stages=num_stages, num_warps=num_warps)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +def user_autotune( + configs, triton_meta, 
filename=None, inductor_meta=None, custom_kernel=False +): + """ + Compile a user defined triton kernel + """ + defaults = inspect.signature(triton.Config).parameters + default_num_stages = defaults["num_stages"].default + default_num_warps = defaults["num_warps"].default + + if len(configs) == 0: + configs = [ + triton.Config( + {}, num_stages=default_num_stages, num_warps=default_num_warps + ) + ] + else: + configs = [ + triton.Config( + c.get("kwargs", {}), + num_stages=c.get("num_stages", default_num_stages), + num_warps=c.get("num_warps", default_num_warps), + ) + for c in configs + ] + + return cached_autotune( + None, + configs, + triton_meta=triton_meta, + heuristic_type=HeuristicType.USER_AUTOTUNE, + filename=filename, + inductor_meta=inductor_meta, + custom_kernel=custom_kernel, + ) + + +def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): + """ + Compile a triton foreach kernel + """ + return cached_autotune( + None, + [triton.Config({}, num_stages=1, num_warps=num_warps)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +def grid(*numels): + """Helper function to compute triton grids""" + if len(numels) == 1: + xnumel, ynumel, znumel = numels[0], None, None + elif len(numels) == 2: + xnumel, ynumel, znumel = numels[1], numels[0], None + elif len(numels) == 3: + xnumel, ynumel, znumel = numels[2], numels[1], numels[0] + else: + raise AssertionError(f"invalid size for numels {len(numels)}") + + def get_grid_dim(numel, block): + if numel is None: + return 1 + if block is None: + return numel + return ceildiv(numel, block) + + def grid_fn(meta): + x_grid = get_grid_dim(xnumel, meta.get("XBLOCK", 1)) + y_grid = get_grid_dim(ynumel, meta.get("YBLOCK", None)) + + max_y_grid = get_max_y_grid() + if znumel is None: + div = ceildiv(y_grid, max_y_grid) + y_grid = ceildiv(y_grid, div) + z_grid = div + else: + z_grid = get_grid_dim(znumel, meta.get("ZBLOCK", None)) + torch._check( + y_grid <= max_y_grid, + lambda: f"Generated y grid beyond 2^16 ({y_grid}) not supported with z dimension present. 
File issue", + ) + + return ( + x_grid, + y_grid, + z_grid, + ) + + setattr(grid_fn, "grid_fn_str", f"grid{numels}") # noqa: B010 + + return grid_fn + + +def split_scan_grid(xnumel, rnumel): + def grid_fn(meta): + assert meta.get("XBLOCK", 1) == 1 + return (ceildiv(rnumel, meta.get("RBLOCK", 1)), xnumel, 1) + + grid_fn_str = f"split_scan_grid({xnumel}, {rnumel})" + setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010 + + return grid_fn + + +def grid_combo_kernels( + *numels, num_kernels, min_blocks, is_sequential, default_meta=None +): + """min_blocks is the minimal size of the grid x dimension""" + if not is_sequential: + # round robin dispatch + numels_agg = list(numels) + for i in range(len(numels_agg)): + if isinstance(numels_agg[i], (list, tuple)): + numels_agg[i] = max(max(numels_agg[i]), 0) # noqa: PLW3301 + kernel_grid_fn = grid(*numels_agg) + + if isinstance(numels[-1], (list, tuple)): + min_blocks_d = max(-min(numels[-1]), 0) * num_kernels + else: + min_blocks_d = None + if min_blocks is None: + assert min_blocks_d is not None + min_blocks = min_blocks_d + else: + assert ( + min_blocks_d is None or min_blocks == min_blocks_d + ), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}" + else: + # sequential dispatch + seq_numels = list(numels) + # x numels are not used here, just a place holder + seq_numels[-1] = 1024 + for i in range(len(seq_numels) - 1): + if isinstance(seq_numels[i], (list, tuple)): + seq_numels[i] = max(seq_numels[i]) + + kernel_grid_fn = grid(*seq_numels) + + def get_grid_dim(numel, block): + if numel is None: + return 1 + if block is None: + return numel + return ceildiv(numel, block) + + def grid_fn(meta): + assert min_blocks is not None, "min_blocks must be a number" + cuda_grid = list(kernel_grid_fn(meta)) + cuda_grid[0] = max(num_kernels * cuda_grid[0], min_blocks) + return tuple(cuda_grid) + + def seq_grid_fn(meta): + cuda_grid = list(kernel_grid_fn(meta)) + # x <= 0 means this kernel's x grid is not tunable (x_no_dim is true) + x_grid = sum( + [ + -x if x <= 0 else get_grid_dim(x, meta.get("XBLOCK", 1)) + for x in numels[-1] + ] + ) + cuda_grid[0] = x_grid + return tuple(cuda_grid) + + def grid_fn_default_meta(meta): + return grid_fn(default_meta) + + def seq_grid_fn_default_meta(meta): + return seq_grid_fn(default_meta) + + if default_meta is None: + return grid_fn if not is_sequential else seq_grid_fn + else: + return grid_fn_default_meta if not is_sequential else seq_grid_fn_default_meta diff --git a/.venv/Lib/site-packages/torch/_lazy/__init__.py b/.venv/Lib/site-packages/torch/_lazy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a3d1de2dfa09c3a1fd1cdad5465ff8edad09c29 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_lazy/__init__.py @@ -0,0 +1,55 @@ +# mypy: allow-untyped-defs + +import torch._C._lazy +from torch.utils._pytree import tree_flatten, tree_unflatten + +from .closure import add_step_closure, run_step_closures + + +def mark_step(device: str = "", wait=False): + """Triggers a mark step, which amounts to + - collecting a group of 'live' lazy tensors to index into the compilation cache + (lowering/compiling their IR graphs if not cached) + - kicking off execution of the compiled function + - (optionally, wait=True) waiting for cpu-side execution to complete (does not sync the accelerator) + """ + # TODO(whc) expand this to include backend hooks and align with XLA backend needs + torch._C._lazy._mark_step(device, [], wait=wait) + + run_step_closures() + + +def wait_device_ops(devices=None): 
+ """Waits for all the async operations on the given devices to complete. + Args: + devices (string..., optional): The devices whose async ops need to be waited + for. If empty, all the local devices will be waited for. + """ + if devices is None: + devices = [] + torch._C._lazy._wait_device_ops(devices=devices) + + +def sync_multi(tensors, devices): + """ + Sync the list of lazy tensors so there IR get lowered for the activate backend + and the compiled computation graph get cached. + """ + torch._C._lazy._sync_multi(tensors, devices) + + +def get_tensor_id(tensor): + """Return a unique id of the lazy tensor maintained by LTC""" + return torch._C._lazy._get_tensor_id(tensor) + + +def to_cpu(tensors, devices=None): + devices = devices or ["lazy"] + + flattened, spec = tree_flatten(tensors) + sync_multi(flattened, devices) + return tree_unflatten([t.to("cpu") for t in flattened], spec) + + +def save(tensors, *args, **kwargs): + torch.save(to_cpu(tensors), *args, **kwargs) diff --git a/.venv/Lib/site-packages/torch/_lazy/extract_compiled_graph.py b/.venv/Lib/site-packages/torch/_lazy/extract_compiled_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..528540344d03c74e8573e3f2c3ecae2571a236fc --- /dev/null +++ b/.venv/Lib/site-packages/torch/_lazy/extract_compiled_graph.py @@ -0,0 +1,225 @@ +# mypy: allow-untyped-defs +import copy +import dataclasses +import itertools +import os +from typing import Any, Callable, Dict, List + +import torch +import torch._lazy as lazy +import torch._lazy.metrics as metrics +from torch import fx +from torch._lazy import computation, debug as lazy_debug +from torch._lazy.tensor_factory_functions import tensor_factory_functions + + +debug = os.environ.get("debug_extract_compiled_graph") is not None + + +@dataclasses.dataclass +class GraphInputMatcher: + """ + The GraphInputMatcher class setup the graph inputs for future calls after lazy tracing. + Specifically, those graph inputs corresponding to method parameters should be replaced with the + arguments for the current call. + + tensor_id_to_arg_idx maps the tensor id to the parameter index. + graph_input_tensor_ids, graph_input_ivalues list the tensor_id and ivalue for each of the + TS/XLA graph inputs. + """ + + tensor_id_to_arg_idx: Dict[int, int] + graph_input_tensor_ids: List[int] + # there are 2 categories of graph_input_tensors. + # Category 1: those whose id are not found in tensor_id_to_arg_idx. These are + # most likely const tensors and we can get its content from graph_input_tensors + # Category 2: those whose id are found in tensor_id_to_arg_idx. We should get + # the tensor from method arguments + graph_input_ivalues: List[Any] + + # get the real graph input tensors + def __call__(self, args): + real_input = [] + for tensor_id, traced_ivalue in zip( + self.graph_input_tensor_ids, self.graph_input_ivalues + ): + arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None) + if arg_idx is None: + inp = traced_ivalue + else: + inp = args[arg_idx] + real_input.append(inp) + return real_input + + +class ReturnValueHandler: + r""" + When ltc_sync_multi is called on multi tensors, the compiled graph + will contain output only for unique tensors - if a tensor appears multiple + times in the input to _ltc_sync_multi, only the first occurance matters. + + However from python level, we still expect multi tensors returned with duplciation + even if the TS graph dedup the output. e.g. 
for method: + + def forward(self, a): + return a, a + + the TS graph captured by LTC will return a single tensor, but Python method expects 2. + + This class dedup the lazy tensors first to get the index that will be used + to duplicate the eager tensors later. + """ + + def __init__(self, lazy_out_list): + self.index: List[List[int]] = [] + self.total_count = len(lazy_out_list) + + tensor_id_to_idx: Dict[int, int] = {} + for dup_idx, lazy_tensor in enumerate(lazy_out_list): + uniq_idx = tensor_id_to_idx.get(id(lazy_tensor), None) + if uniq_idx is not None: + self.index[uniq_idx].append(dup_idx) + else: + uniq_idx = len(self.index) + self.index.append([dup_idx]) + tensor_id_to_idx[id(lazy_tensor)] = uniq_idx + + def duplicate_eager_tensors(self, eager_tensor_list): + duplicated_list = [None] * self.total_count + assert len(eager_tensor_list) == len(self.index) + + for uniq_idx, eager_tensor in enumerate(eager_tensor_list): + for dup_idx in self.index[uniq_idx]: + duplicated_list[dup_idx] = eager_tensor + return duplicated_list + + +def force_lazy_device(model: fx.GraphModule): + """ + Factory methods in a Fx graph may create tensors for a specific eager devices. + If we take no actions, those eager tensors will be mixed with lazy tensors and + cause crash. This method overwrite those eager device to lazy device. + """ + + def tolazydevice(dev): + if isinstance(dev, torch.device): + return torch.device("lazy", index=dev.index) + return dev + + def hasDeviceArg(args, kwargs): + return any( + isinstance(arg, torch.device) + for arg in itertools.chain(args, kwargs.values()) + ) + + for nd in model.graph.nodes: + nd.args = tuple(tolazydevice(arg) for arg in nd.args) + nd.kwargs = {k: tolazydevice(v) for k, v in nd.kwargs.items()} + + # For torchbench like yolov3, hf_Bart, dynamo generates Fx graph that return + # eager tensors on the default device + # (check https://gist.github.com/shunting314/eabdf6c769c59bc384469717b8f9bb7f for yolove, + # and https://gist.github.com/shunting314/8d5e2d9348a3258959d3954186c48814 for hf_Bart). + # To force those tensors on the lazy device, we can not simply override + # the device argument since there is no explicit device argument. + # What we are doing here is, for the list of covered tensor factory methods + # we add a lazy device argument explicity. + # + # TODO: This solution is no ideal since we may miss some factory methods. In future + # when we support lazy mode, this method can be replaced by that. + if nd.target in tensor_factory_functions and not hasDeviceArg( + nd.args, nd.kwargs + ): + kwargs = dict(nd.kwargs) # nd.kwargs is immutable. make a mutable copy. + kwargs["device"] = torch.device("lazy") + nd.kwargs = kwargs + + model.recompile() + + +def get_fallback_ops(): + fallback_ops = [] + for opname in metrics.counter_names(): + if "aten::" not in opname: + continue + val = int(metrics.counter_value(opname)) + if val > 0: + fallback_ops.append(f"{opname}={val}") + + return fallback_ops + + +def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable: + """ + Optimize an eager model with LTC and returns a wrapper to execute the + compiled graph directly without retracing. It depends on other mechanisms + like TorchDynamo guards to guarantee the returned wrapper is only called + when it's safe. 
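# Editor's note: illustrative sketch, not part of the vendored diff. It shows the
# dedup/re-duplicate bookkeeping that ReturnValueHandler (defined above) performs,
# using plain Python objects instead of lazy tensors; every name here is hypothetical.
from typing import Dict, List


def build_dup_index(outputs: List[object]) -> List[List[int]]:
    # For each unique object (by identity), record every position it occupies.
    index: List[List[int]] = []
    seen: Dict[int, int] = {}
    for pos, obj in enumerate(outputs):
        uniq = seen.get(id(obj))
        if uniq is None:
            seen[id(obj)] = len(index)
            index.append([pos])
        else:
            index[uniq].append(pos)
    return index


def duplicate(unique_results: List[object], index: List[List[int]], total: int) -> List[object]:
    # Fan the per-unique results back out to the original (possibly duplicated) positions.
    out: List[object] = [None] * total
    for uniq, result in enumerate(unique_results):
        for pos in index[uniq]:
            out[pos] = result
    return out


a = object()
idx = build_dup_index([a, a])               # one unique output -> [[0, 1]]
print(duplicate(["graph_result"], idx, 2))  # ['graph_result', 'graph_result']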
+ """ + lazy_args = [arg.to(device="lazy") for arg in example_inputs] + args_tensor_ids = [lazy.get_tensor_id(lazy_arg) for lazy_arg in lazy_args] + tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)} + lazy_model = copy.deepcopy(model).to(device=torch.device("lazy")) + force_lazy_device(lazy_model) + + # This line executes lazy tracing and enable us extracting compiled graph later + metrics.reset() + lazy_out = lazy_model(*lazy_args) + fallback_ops = get_fallback_ops() + metrics.reset() + + if len(fallback_ops) > 0: + raise RuntimeError( + f"Fail to extact the compiled graph because of fallback: {','.join(fallback_ops)}" + ) + + if not isinstance(lazy_out, (tuple, list)): + lazy_out = (lazy_out,) + + args_and_out = tuple(lazy_args) + tuple(lazy_out) + return_value_handler = ReturnValueHandler(args_and_out) + if debug: + print("Fx code:\n", model.code) + print("LTC IR:", lazy_debug.dump_ir(args_and_out, "text")) + + # TODO: this part is TS backend specific for now and will be generalized to + # support XLA + ( + graph_input_tensor_ids, + graph_input_ivalues, + ) = computation.get_tensors_ts_device_data_node(args_and_out) + assert len(graph_input_tensor_ids) == len(graph_input_ivalues) + graph_input_matcher = GraphInputMatcher( + tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues + ) + + graph_hash = computation.get_graph_hash(args_and_out) + + if debug: + print("graph_hash", graph_hash) + print(f"args_tensor_ids {args_tensor_ids}") + print("tensor ids from device data:", graph_input_tensor_ids) + + # sync the list of output tensors so the computation graph for these + # tensors will be cached. Those computation graphs can be retrieved + # by graph hash later. + lazy.sync_multi(args_and_out, []) + + def optimized_mod(*args): + if len(args_and_out) == 0: + return () + graph_input = graph_input_matcher(args) + res = return_value_handler.duplicate_eager_tensors( + computation.run_cached_graph(graph_hash, graph_input) + ) + + assert len(res) == len(args_and_out) + for i, arg in enumerate(args): + # only copy those tensors that get inplace updated + if arg is not res[i]: + arg.copy_(res[i]) + + # skip the args + return res[len(args) :] + + return optimized_mod diff --git a/.venv/Lib/site-packages/torch/_lazy/ir_cache.py b/.venv/Lib/site-packages/torch/_lazy/ir_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..6b0c4037405bba1bf64c2e41c5dc5ea899eb9281 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_lazy/ir_cache.py @@ -0,0 +1,14 @@ +# mypy: allow-untyped-defs +import torch._C._lazy + + +def dump(dot_file_name: str): + """Dump TrieCache in the dot format""" + return torch._C._lazy._dump_ir_cache(dot_file_name) + + +def reset(): + """Clear TrieCache. This is needed in testing to avoid + node reusing between different tests. 
+ """ + return torch._C._lazy._clear_ir_cache() diff --git a/.venv/Lib/site-packages/torch/_lazy/metrics.py b/.venv/Lib/site-packages/torch/_lazy/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..4879c3e4800c099b959d4f964b00de2e9cdc2306 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_lazy/metrics.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch._C._lazy + + +def reset(): + """Resets all metric counters.""" + torch._C._lazy._reset_metrics() + + +def counter_names(): + """Retrieves all the currently active counter names.""" + return torch._C._lazy._counter_names() + + +def counter_value(name: str): + """Return the value of the counter with the speficied name""" + return torch._C._lazy._counter_value(name) + + +def metrics_report(): + """Return the combined (lazy core and backend) metric report""" + return torch._C._lazy._metrics_report() diff --git a/.venv/Lib/site-packages/torch/_lazy/tensor_factory_functions.py b/.venv/Lib/site-packages/torch/_lazy/tensor_factory_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..98b4ec6c5cea99edb937a6160141af360621823f --- /dev/null +++ b/.venv/Lib/site-packages/torch/_lazy/tensor_factory_functions.py @@ -0,0 +1,49 @@ +import torch + + +""" +tensor_factory_functions defines the list of torch functions that create tensors. +The list is grabbed by searching thru native_functions.yaml by the following +regular expression: + + cat native_functions.yaml | grep 'func:' | grep -v "Tensor.*->" | grep "[-]>.*Tensor" + +It's possible that new tensor factory functions are added making this list stale. +Use at your own risk or regenerate the list. +""" +tensor_factory_functions = ( + torch._cudnn_init_dropout_state, + torch.arange, + torch.bartlett_window, + torch.blackman_window, + torch._empty_affine_quantized, + torch.empty_strided, + torch.eye, + torch.full, + torch.from_file, + torch.hann_window, + torch.hamming_window, + torch.kaiser_window, + torch.linspace, + torch.logspace, + torch.ones, + torch.scalar_tensor, + torch.rand, + torch.randint, + torch.randn, + torch.randperm, + torch.range, + torch._efficientzerotensor, + torch.zeros, + torch.tril_indices, + torch.triu_indices, + # Note: the following functions match the regular expression search above but + # they are not available in the torch module. Comment out. 
+ # torch._sparse_coo_tensor_with_dims, + # torch.fft_fftfreq, + # torch.fft_rfftfreq, +) + ( + # torch.tensor is special since it's not in native_functions.yaml + # add it separately + torch.tensor, +) diff --git a/.venv/Lib/site-packages/torch/_lazy/ts_backend.py b/.venv/Lib/site-packages/torch/_lazy/ts_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..6050e9dca4555553bba07eaec11c2ea6137d4b30 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_lazy/ts_backend.py @@ -0,0 +1,7 @@ +# mypy: allow-untyped-defs +import torch._C._lazy_ts_backend + + +def init(): + """Initializes the lazy Torchscript backend""" + torch._C._lazy_ts_backend._init() diff --git a/.venv/Lib/site-packages/torch/_library/autograd.py b/.venv/Lib/site-packages/torch/_library/autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..ed919df2f7b62506a039eda06265480d49471513 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_library/autograd.py @@ -0,0 +1,241 @@ +# mypy: allow-untyped-defs +import dataclasses +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Protocol + +from torch import _C, _ops, autograd, Tensor +from torch.utils import _pytree + +from . import utils + + +class InfoProtocol(Protocol): + _backward_fn: Optional[Callable] + _setup_context_fn: Optional[Callable] + + +@dataclasses.dataclass +class Info: + _backward_fn: Optional[Callable] + _setup_context_fn: Optional[Callable] + + +def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable: + name: str = f"GeneratedBackwardFor_{op._namespace}_{op._opname}_{op._overloadname}" + + has_kwarg_only_args = utils.has_kwarg_only_args(op._schema) + + @dataclass + class Metadata: + keyset: _C.DispatchKeySet + keyword_only_args: Dict[str, Any] + + def forward_no_grad(*args): + metadata = args[-1] + args = args[:-1] + + with _C._AutoDispatchBelowAutograd(): + keyset = metadata.keyset + kwargs = metadata.keyword_only_args + result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs) + return result + + def forward(ctx, *args): + metadata = args[-1] + args = args[:-1] + + with _C._AutoDispatchBelowAutograd(): + keyset = metadata.keyset + kwargs = metadata.keyword_only_args + result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs) + if info._setup_context_fn: + # The Dispatcher will remove args that are equal to their default + # values from (args, kwargs). We're going to add it back so that + # the user can access them. + # + # This is OK to do: The Dispatcher removed the args for serialization + # FC/BC reasons (that is, a graph will not store args that are equal + # to their default values), but that doesn't matter here. 
If the user + # adds a new default arg, then they must update + # their setup_context (along with the rest of their operator + # registrations) + args, kwargs = utils.fill_defaults(op._schema, args, kwargs) + + if has_kwarg_only_args: + info._setup_context_fn( + ctx=ctx, inputs=args, keyword_only_inputs=kwargs, output=result + ) + else: + info._setup_context_fn(ctx=ctx, inputs=args, output=result) + return result + + def backward(ctx, *grads): + if info._backward_fn: + try: + prev_needs_input_grad = ctx.needs_input_grad + ctx.needs_input_grad = ctx.needs_input_grad[:-1] + result = info._backward_fn(ctx, *grads) + finally: + ctx.needs_input_grad = prev_needs_input_grad + if isinstance(result, tuple): + return (*result, None) + return result, None + raise RuntimeError( + f"Trying to backward through {op} but no autograd " + f"formula was registered. " + f"Please use register_autograd to add one." + ) + + Generated = type( + name, + (autograd.Function,), + { + "forward": staticmethod(forward), + "backward": staticmethod(backward), + }, + ) + + schema = op._schema + if any( + utils.is_tensorlist_like_type(a.type) + for a in (*schema.arguments, *schema.returns) + ): + Generated = supports_tensorlist(Generated) + + # The dispatcher passes any keyword-only-args as kwargs and the + # rest of the args (even if specified as kwargs) as args. + def autograd_impl(keyset, *args, **keyword_only_args): + if _C.is_grad_enabled() and _pytree.tree_any_only( + Tensor, lambda x: x.requires_grad, args, not_list_of_tensor + ): + result = Generated.apply(*args, Metadata(keyset, keyword_only_args)) # type: ignore[attr-defined] + else: + result = forward_no_grad(*args, Metadata(keyset, keyword_only_args)) + return result + + return autograd_impl + + +def supports_tensorlist(cls: Any) -> Any: + """Allows a given autograd.Function class to support List[Tensor] inputs/outputs. + + Regular autograd.Function has a constraint that it only directly supports autograd for + Tensors. Applying @supports_tensorlist enables an autograd.Function to support + autograd for List[Tensor] inputs and outputs. + """ + orig_forward = cls.forward + orig_backward = cls.backward + orig_apply = cls.apply + + @dataclass + class Metadata: + input_spec: spec_t + output_spec: Optional[spec_t] = None + result_is_tuple: Optional[bool] = None + + def new_forward(ctx, *args): + metadata = args[-1] + args = args[:-1] + if not isinstance(metadata, Metadata): + raise NotImplementedError( + "NYI: calling supports_tensorlist autograd.Function.forward directly. " + "You should probably be calling .apply instead. " + "Please file an issue if not." + ) + args = unflatten(list(args), metadata.input_spec) + result = orig_forward(ctx, *args) + metadata.result_is_tuple = isinstance(result, tuple) + if not metadata.result_is_tuple: + result = (result,) + flat_result, output_spec = flatten(result, not_list_of_tensor) + metadata.output_spec = output_spec + + if hasattr(ctx, "_pt_metadata"): + raise RuntimeError( + "Please don't set ctx._pt_metadata; PyTorch uses it to store info" + ) + ctx._pt_metadata = metadata + + return tuple(flat_result) + + def new_backward(ctx, *grads): + if not hasattr(ctx, "_pt_metadata"): + raise NotImplementedError( + "NYI: calling supports_tensorlist autograd.Function.backward directly. " + "This will automatically get called by PyTorch autograd. " + "Please file an issue if you need this." 
+ ) + + metadata = ctx._pt_metadata + grads = unflatten(list(grads), metadata.output_spec) + + # If the user's input is ([x, y, z], w), + # then needs_input_grad is (bool, bool, bool, bool, bool). + # We need to + # 1. get rid of the additional bool (which comes from the extra + # `metadata input`) + # 2. unflatten to get the right structure. + prev_needs_input_grad = ctx.needs_input_grad + try: + ctx.needs_input_grad = unflatten( + list(ctx.needs_input_grad[:-1]), metadata.input_spec + ) + grad_inputs = orig_backward(ctx, *grads) + finally: + ctx.needs_input_grad = prev_needs_input_grad + + if not isinstance(grad_inputs, tuple): + grad_inputs = (grad_inputs,) + # Assume that any Nones in the backward are Tensors. + # If the forward has an arg that is [1, 2, 3], the backward should + # return None as the grad. + # If the forward has an arg that is [tensor, tensor], the backward + # may return [None, None], [grad, None], [None, grad], or [grad, grad]. + flat_grad_inputs, grad_inputs_spec = flatten( + grad_inputs, not_list_of_optional_tensor + ) + if grad_inputs_spec != metadata.input_spec: + raise RuntimeError( + f"Expected the return from backward to be of the same structure " + f"as the inputs. Got: {grad_inputs_spec} (return from backward), " + f"{metadata.input_spec} (inputs)" + ) + return tuple(flat_grad_inputs + [None]) + + def new_apply(*args): + flat_args, input_spec = flatten(args, is_leaf=not_list_of_tensor) + metadata = Metadata(input_spec) + result = orig_apply(*flat_args, metadata) # type: ignore[misc] + assert metadata.output_spec is not None + result = unflatten(list(result), metadata.output_spec) + if not metadata.result_is_tuple: + assert isinstance(result, tuple) + assert len(result) == 1 + return result[0] + return result + + cls.forward = new_forward + cls.backward = new_backward + cls.apply = new_apply + return cls + + +def not_list_of_tensor(tree): + if isinstance(tree, tuple): + return False + if isinstance(tree, list): + return any(not isinstance(l, Tensor) for l in tree) + return True + + +def not_list_of_optional_tensor(tree): + if isinstance(tree, tuple): + return False + if isinstance(tree, list): + return any(l is not None and not isinstance(l, Tensor) for l in tree) + return True + + +flatten = _pytree.tree_flatten +unflatten = _pytree.tree_unflatten +spec_t = _pytree.TreeSpec diff --git a/.venv/Lib/site-packages/torch/_library/custom_ops.py b/.venv/Lib/site-packages/torch/_library/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..fce6098263cffd8a82108f575e260bf6fb66f3b5 --- /dev/null +++ b/.venv/Lib/site-packages/torch/_library/custom_ops.py @@ -0,0 +1,835 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import inspect +import logging +import weakref +from contextlib import contextmanager +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import torch +from torch import _C, _ops, Tensor +from torch.utils._exposed_in import exposed_in + +from . import autograd, utils + + +device_types_t = Optional[Union[str, Sequence[str]]] +log = logging.getLogger(__name__) + + +@exposed_in("torch.library") +def custom_op( + name: str, + fn: Optional[Callable] = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + device_types: device_types_t = None, + schema: Optional[str] = None, +) -> Callable: + """Wraps a function into custom operator. 
+ + Reasons why you may want to create a custom op include: + - Wrapping a third-party library or custom kernel to work with PyTorch + subsystems like Autograd. + - Preventing torch.compile/export/FX tracing from peeking inside your function. + + This API is used as a decorator around a function (please see examples). + The provided function must have type hints; these are needed to interface + with PyTorch's various subsystems. + + Args: + name (str): A name for the custom op that looks like "{namespace}::{name}", + e.g. "mylib::my_linear". The name is used as the op's stable identifier + in PyTorch subsystems (e.g. torch.export, FX graphs). + To avoid name collisions, please use your project name as the namespace; + e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace. + mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates. + This MUST be accurate, otherwise, the behavior is undefined. If "unknown", + it pessimistically assumes that all inputs to the operator are being mutated. + device_types (None | str | Sequence[str]): The device type(s) the function + is valid for. If no device type is provided, then the function + is used as the default implementation for all device types. + Examples: "cpu", "cuda". + When registering a device-specific implementation for an operator that accepts no Tensors, + we require the operator to have a "device: torch.device argument". + schema (None | str): A schema string for the operator. If None + (recommended) we'll infer a schema for the operator from its type + annotations. We recommend letting us infer a schema unless you + have a specific reason not to. + Example: "(Tensor x, int y) -> (Tensor, Tensor)". + + .. note:: + We recommend not passing in a ``schema`` arg and instead letting us infer + it from the type annotations. It is error-prone to write your own schema. + You may wish to provide your own schema if our interpretation of + the type annotation is not what you want. + For more info on how to write a schema string, see + `here `_ + + Examples:: + >>> import torch + >>> from torch import Tensor + >>> from torch.library import custom_op + >>> import numpy as np + >>> + >>> @custom_op("mylib::numpy_sin", mutates_args=()) + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> x = torch.randn(3) + >>> y = numpy_sin(x) + >>> assert torch.allclose(y, x.sin()) + >>> + >>> # Example of a custom op that only works for one device type. 
+ >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu") + >>> def numpy_sin_cpu(x: Tensor) -> Tensor: + >>> x_np = x.numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np) + >>> + >>> x = torch.randn(3) + >>> y = numpy_sin_cpu(x) + >>> assert torch.allclose(y, x.sin()) + >>> + >>> # Example of a custom op that mutates an input + >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu") + >>> def numpy_sin_inplace(x: Tensor) -> None: + >>> x_np = x.numpy() + >>> np.sin(x_np, out=x_np) + >>> + >>> x = torch.randn(3) + >>> expected = x.sin() + >>> numpy_sin_inplace(x) + >>> assert torch.allclose(x, expected) + >>> + >>> # Example of a factory function + >>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu") + >>> def bar(device: torch.device) -> Tensor: + >>> return torch.ones(3) + >>> + >>> bar("cpu") + + """ + + def inner(fn): + import torch + + if schema is None: + schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args) + else: + schema_str = schema + + namespace, opname = name.split("::") + result = CustomOpDef(namespace, opname, schema_str, fn) + if schema is not None: + # Check that schema's alias annotations match those of `mutates_args`. + expected = set() + for arg in result._opoverload._schema.arguments: + if arg.alias_info is not None and arg.alias_info.is_write: + expected.add(arg.name) + if expected != set(mutates_args): + raise ValueError( + f"Attempted to create a custom op with `mutates_args={mutates_args}` " + f"and `schema={schema}. The schema suggests that the op mutates {expected}" + f"which is different from what was provided to us in `mutates_args`. " + f"Please make these consistent." + ) + result.register_kernel(device_types)(fn) + return result + + if fn is None: + return inner + return inner(fn) + + +class CustomOpDef: + """CustomOpDef is a wrapper around a function that turns it into a custom op. + + It has various methods for registering additional behavior for this + custom op. + + You should not instantiate CustomOpDef directly; instead, use the + :func:`torch.library.custom_op` API. + """ + + def __init__(self, namespace: str, name: str, schema: str, fn: Callable) -> None: + # Fields used to interface with the PyTorch dispatcher + self._namespace = namespace + self._name = name + self._schema = schema + + self._init_fn = fn + + self._backend_fns: Dict[Union[str, None], Callable] = {} + self._abstract_fn: Optional[Callable] = None + self._setup_context_fn: Optional[Callable] = None + self._backward_fn: Optional[Callable] = None + self._torch_dispatch_fns: Dict[type, Callable] = {} + self._vmap_fn: Optional[Callable] = None + + self._lib = get_library_allowing_overwrite(self._namespace, self._name) + self._register_to_dispatcher() + self._disabled_kernel: Set = set() + OPDEFS[self._qualname] = self + + @property + def _qualname(self) -> str: + return f"{self._namespace}::{self._name}" + + def __repr__(self) -> str: + return f"" + + @contextmanager + def set_kernel_enabled(self, device_type: str, enabled: bool = True): + """ + Disable or re-enable an already registered kernel for this custom operator. + + If the kernel is already disabled/enabled, this is a no-op. + + Note: + If a kernel is first disabled and then registered, it is disabled until enabled again. + + Args: + device_type (str): The device type to disable/enable the kernel for. + disable (bool): Whether to disable or enable the kernel. 
+ + Example: + >>> inp = torch.randn(1) + >>> + >>> # define custom op `f`. + >>> @custom_op("mylib::f", mutates_args=()) + >>> def f(x: Tensor) -> Tensor: + >>> return torch.zeros(1) + >>> + >>> print(f(inp)) # tensor([0.]), default kernel + >>> + >>> @f.register_kernel("cpu") + >>> def _(x): + >>> return torch.ones(1) + >>> + >>> print(f(inp)) # tensor([1.]), CPU kernel + >>> + >>> # temporarily disable the CPU kernel + >>> with f.set_kernel_enabled("cpu", enabled = False): + >>> print(f(inp)) # tensor([0.]) with CPU kernel disabled + + """ + action = "enable" if enabled else "disable" + originally_disabled = device_type in self._disabled_kernel + if device_type not in self._backend_fns: + log.warning( + "Attempted to %s kernel for %s but no kernel was registered for this device type.", + action, + device_type, + ) + + if not enabled: + if originally_disabled: + log.warning( + "Attempted to disable kernel for %s but it was already disabled.", + device_type, + ) + else: + self._disabled_kernel.add(device_type) + else: # enable the kernel + if not originally_disabled: + log.warning( + "Attempted to enable kernel for %s but it was already enabled.", + device_type, + ) + else: + self._disabled_kernel.remove(device_type) + + try: + yield + finally: + # restore original state + if originally_disabled: + self._disabled_kernel.add(device_type) + else: + self._disabled_kernel.discard(device_type) + + def register_kernel( + self, device_types: device_types_t, fn: Optional[Callable] = None, / + ) -> Callable: + """Register an implementation for a device type for this operator. + + Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu". + This API may be used as a decorator. + + Args: + fn (Callable): The function to register as the implementation for + the given device types. + device_types (str | Sequence[str]): The device device_types to register an impl to. + + Examples:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> from torch import Tensor + >>> from torch.library import custom_op + >>> import numpy as np + >>> + >>> # Create a custom op that works on cpu + >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu") + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np) + >>> + >>> # Add implementations for the cuda device + >>> @numpy_sin.register_kernel("cuda") + >>> def _(x): + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> x_cpu = torch.randn(3) + >>> x_cuda = x_cpu.cuda() + >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin()) + >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin()) + + """ + + def inner(fn): + if device_types is None or isinstance(device_types, str): + dtypes: List[Union[str, None]] = [device_types] + else: + dtypes = list(device_types) + for device_type in dtypes: + if device_type not in self._backend_fns: + + def backend_impl(*args, **kwargs): + # Checks the assumption that outputs cannot alias + # inputs or other outputs. 
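# Editor's note: illustrative sketch, not part of the vendored diff. It shows a
# simplified variant of the aliasing check that backend_impl performs below: rather
# than collecting storage object ids, it compares storage base pointers, which is
# enough to demonstrate when an output shares memory with an input. The helper name
# is hypothetical.
import torch


def output_aliases_inputs(inputs, outputs) -> bool:
    # Two tensors alias if their untyped storages point at the same allocation.
    input_ptrs = {t.untyped_storage().data_ptr() for t in inputs}
    return any(t.untyped_storage().data_ptr() in input_ptrs for t in outputs)


x = torch.randn(4)
print(output_aliases_inputs([x], [x.view(2, 2)]))  # True: a view shares x's storage
print(output_aliases_inputs([x], [x.clone()]))     # False: clone() allocates new storage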
+ storages = { + id(tensor.untyped_storage()) + for tensor in iter_tensors(args, kwargs) + } + + result = self._backend_fns[device_type](*args, **kwargs) + + tuple_result = result + if not isinstance(result, tuple): + tuple_result = (result,) + for tensor in iter_tensors(tuple_result, {}): + key = id(tensor.untyped_storage()) + if id(tensor.untyped_storage()) in storages: + fn = self._backend_fns[device_type] + module = inspect.getmodule(fn) + raise RuntimeError( + f"{self._name} (with implementation in {module}): " + f"The output of this custom operator (1) must not " + f"also be an input to this custom operator and " + f"(2) may not alias any inputs to this custom operator " + f"or other returns. " + f"The most common way to trigger this error is if " + f"we have y = custom_op(x) and y and x are the same Tensor. " + f"Please instead return a clone of the offending output " + f"tensor(s) (e.g. return x.clone()) or refactor the custom " + f"operator to not return y." + ) + storages.add(key) + return result + + if device_type is None: + self._lib.impl( + self._name, backend_impl, "CompositeExplicitAutograd" + ) + else: + self._lib.impl( + self._name, + backend_impl, + _C._dispatch_key_for_device(device_type), + ) + + # Wrap function to choose between the default implementation or the device-specific + # implementation depending on if the kernel is disabled. + @torch._disable_dynamo + def wrapped_fn(*args, **kwargs): + if device_type in self._disabled_kernel: + return self._init_fn(*args, **kwargs) + else: + return fn(*args, **kwargs) + + self._backend_fns[device_type] = wrapped_fn + return fn + + if device_types is not None and not utils.has_tensor_arg( + self._opoverload._schema + ): + device_arg_index = utils.get_device_arg_index(self._opoverload._schema) + if device_arg_index is None: + raise ValueError( + "Functions without tensor inputs are required to have a `device: torch.device` argument" + ) + self._register_backend_select_dispatcher(device_arg_index) + + # See NOTE: [Supporting decorator and non-decorator usage] + if fn is None: + return inner + return inner(fn) + + def register_fake(self, fn: Callable, /) -> Callable: + r"""Register a FakeTensor implementation for this custom op. + + This is necessary to get the operator to work efficiently with torch.compile. + + The Fake impl (sometimes also known as a meta kernel or abstract impl) + specifies the behavior of this operator on Tensors that carry no data. + Given some input Tensors with certain properties + (sizes/strides/storage_offset/device), it specifies what the properties of + the output Tensors are. + + Please see :func:`torch.library.impl_abstract` for more details. + + Args: + fn (Callable): The function to register as the FakeTensor + implementation. 
+ + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Example 1: an operator without data-dependent output shape + >>> @torch.library.custom_op("mylib::linear", mutates_args=()) + >>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: + >>> return (x @ weight.t()) + bias + >>> + >>> @linear.register_fake + >>> def _(x, weight, bias): + >>> assert x.dim() == 2 + >>> assert weight.dim() == 2 + >>> assert bias.dim() == 1 + >>> assert x.shape[1] == weight.shape[1] + >>> assert weight.shape[0] == bias.shape[0] + >>> assert x.device == weight.device + >>> return x.new_empty(x.size(0), weight.size(0)) + >>> + >>> x = torch.randn(2, 2) + >>> weight = torch.randn(2, 2) + >>> bias = torch.randn(2) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> out = torch.compile(linear, fullgraph=True)(x, weight, bias) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias)) + >>> + >>> # Example 2: an operator with data-dependent output shape + >>> @torch.library.custom_op("mylib::nonzero", mutates_args=()) + >>> def nonzero(x: Tensor) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> return torch.tensor(res, device=x.device) + >>> + >>> @nonzero.register_fake + >>> def _(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an abstract impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. + >>> ctx = torch.library.get_ctx() + >>> nnz = ctx.new_dynamic_size() + >>> shape = [nnz, x.dim()] + >>> result = x.new_empty(shape, dtype=torch.int64) + >>> return result + >>> + >>> x = torch.tensor([0, 1, 2, 0, 0, 1]) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> out = torch.compile(nonzero, fullgraph=True)(x) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> assert torch.allclose(out, x.nonzero()) + + """ + self._abstract_fn = fn + return fn + + def register_torch_dispatch( + self, torch_dispatch_class: Any, fn: Optional[Callable] = None, / + ) -> Callable: + r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``. + + This allows for open registration to specify the behavior between the operator + and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class`` + or the operator directly. + + Please see :func:`torch.library.register_torch_dispatch` for examples and more details. + """ + + def register(fn): + if torch_dispatch_class not in self._torch_dispatch_fns: + + def inner(*args, **kwargs): + return self._torch_dispatch_fns[torch_dispatch_class]( + *args, **kwargs + ) + + self._lib._register_torch_dispatch_rule( + self._name, torch_dispatch_class, inner + ) + self._torch_dispatch_fns[torch_dispatch_class] = fn + return fn + + if fn is None: + return register + else: + return register(fn) + + def register_autograd( + self, + backward: Callable, + /, + *, + setup_context: Optional[Callable] = None, + ) -> None: + r"""Register a backward formula for this custom op. + + In order for an operator to work with autograd, you need to register + a backward formula: + 1. You must tell us how to compute gradients during the backward pass + by providing us a "backward" function. + 2. If you need any values from the forward to compute gradients, you can + use `setup_context` to save values for backward. + + ``backward_fn`` runs during the backward pass. 
It accepts ``(ctx, *grads)``: + - ``grads`` is one or more gradients. The number of gradients matches + the number of outputs of the operator. + The ``ctx`` object is `the same ctx object `_ used by + :class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the + same as :meth:`torch.autograd.Function.backward`. + + ``setup_context(ctx, inputs, output)`` runs during the forward pass. + Please save quantities needed for backward onto the ``ctx`` object via + either :meth:`torch.autograd.function.FunctionCtx.save_for_backward` + or assigning them as attributes of ``ctx``. If your custom op has + kwarg-only arguments, we expect the signature of ``setup_context`` + to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``. + + Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is, + they may not directly access :meth:`torch.Tensor.data_ptr` and they must + not depend on or mutate global state. If you need a non-traceable backward, + you can make it a separate custom_op that you call inside ``backward_fn``. + + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> def setup_context(ctx, inputs, output) -> Tensor: + >>> x, = inputs + >>> ctx.save_for_backward(x) + >>> + >>> def backward(ctx, grad): + >>> x, = ctx.saved_tensors + >>> return grad * x.cos() + >>> + >>> numpy_sin.register_autograd(backward, setup_context=setup_context) + >>> + >>> x = torch.randn(3, requires_grad=True) + >>> y = numpy_sin(x) + >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y)) + >>> assert torch.allclose(grad_x, x.cos()) + >>> + >>> # Example with a keyword-only arg + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = x_np * val + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor: + >>> ctx.val = keyword_only_inputs["val"] + >>> + >>> def backward(ctx, grad): + >>> return grad * ctx.val + >>> + >>> numpy_mul.register_autograd(backward, setup_context=setup_context) + >>> + >>> x = torch.randn(3, requires_grad=True) + >>> y = numpy_mul(x, val=3.14) + >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y)) + >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14)) + + """ + schema = self._opoverload._schema + if not utils.is_functional_schema(schema): + raise RuntimeError( + f"Cannot register autograd formula for non-functional operator " + f"{self} with schema {schema}. Please create " + f"a functional operator and register an autograd formula for that." + ) + + self._backward_fn = backward + self._setup_context_fn = setup_context + + def _register_to_dispatcher(self) -> None: + lib = self._lib + schema_str = self._name + self._schema + cpp_schema = _C.parse_schema(schema_str) + if utils.has_kwarg_only_tensors(cpp_schema): + # If you want to support this, the progression is: + # - supporting kwarg-only Tensors that are non-differentiable + # - supporting kwarg-only Tensors (regardless of differentiability) + raise NotImplementedError( + f"custom_op with kwarg-only Tensor args. Please make your " + f"tensors not kwarg-only. 
Got: {schema_str}" + ) + + lib.define( + schema_str, + tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order], + ) + self._opoverload = utils.lookup_op(self._qualname) + + def fake_impl(*args, **kwargs): + if self._abstract_fn is None: + if utils.can_generate_trivial_fake_impl(self._opoverload): + return None + raise RuntimeError( + f"There was no fake impl registered for {self}. " + f"This is necessary for torch.compile/export/fx tracing to work. " + f"Please use `{self._init_fn.__name__}.register_fake` to add an " + f"fake impl." + ) + return self._abstract_fn(*args, **kwargs) + + lib._register_fake(self._name, fake_impl, _stacklevel=4) + + autograd_impl = autograd.make_autograd_impl(self._opoverload, self) + lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True) + + schema = self._opoverload._schema + if schema.is_mutable: + + def adinplaceorview_impl(keyset, *args, **kwargs): + for arg, val in utils.zip_schema(schema, args, kwargs): + if not arg.alias_info: + continue + if not arg.alias_info.is_write: + continue + if isinstance(val, Tensor): + torch.autograd.graph.increment_version(val) + elif isinstance(val, (tuple, list)): + for v in val: + if isinstance(v, Tensor): + torch.autograd.graph.increment_version(v) + with _C._AutoDispatchBelowADInplaceOrView(): + return self._opoverload.redispatch( + keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs + ) + + lib.impl( + self._name, + adinplaceorview_impl, + "ADInplaceOrView", + with_keyset=True, + ) + + def _register_backend_select_dispatcher(self, device_arg_index: int): + """ + Switch on the device argument to select the correct backend to dispatch to. + """ + + def backend_select(keyset, *args, **kwargs): + device = args[device_arg_index].type + if device not in self._backend_fns: + raise RuntimeError( + f"{self._name} does not have a kernel registered for {device}. " + "Please use register_kernel to do so." + ) + dispatch_key = _C._dispatch_key_for_device(device) + dispatch_key = getattr(_C.DispatchKey, dispatch_key) + return self._opoverload.redispatch( + _C.DispatchKeySet(dispatch_key), *args, **kwargs + ) + + self._lib.impl(self._name, backend_select, "BackendSelect", with_keyset=True) + + def __call__(self, *args, **kwargs): + return self._opoverload(*args, **kwargs) + + def register_vmap( + self, + func: Optional[Callable] = None, + ): + r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op. + + This API may be used as a decorator. + + In order for an operator to work with :func:`torch.vmap`, you may need to register a + vmap implementation in the following signature: + + ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``, + + where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``. + + It specifies how do we compute the batched version of ``op`` given inputs with an additional + dimension (specified by ``in_dims``). + + For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None`` + if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer + specifying what dimension of the Tensor is being vmapped over. + + ``info`` is a collection of additional metadata that may be helpful: + ``info.batch_size`` specifies the size of the dimension being vmapped over, while + ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`. + + The return of the function ``func`` is a tuple of ``(output, out_dims)``. 
Similar to ``in_dims``, + ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim`` + per output that specifies if the output has the vmapped dimension and what index it is in. + + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> from typing import Tuple + >>> + >>> def to_numpy(tensor): + >>> return tensor.cpu().numpy() + >>> + >>> lib = torch.library.Library("mylib", "FRAGMENT") + >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=()) + >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: + >>> x_np = to_numpy(x) + >>> dx = torch.tensor(3 * x_np ** 2, device=x.device) + >>> return torch.tensor(x_np ** 3, device=x.device), dx + >>> + >>> def numpy_cube_vmap(info, in_dims, x): + >>> result = numpy_cube(x) + >>> return result, (in_dims[0], in_dims[0]) + >>> + >>> numpy_cube.register_vmap(numpy_cube_vmap) + >>> + >>> x = torch.randn(3) + >>> torch.vmap(numpy_cube)(x) + >>> + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor: + >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) + >>> + >>> @numpy_mul.register_vmap + >>> def numpy_mul_vmap(info, in_dims, x, y): + >>> x_bdim, y_bdim = in_dims + >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + >>> result = x * y + >>> result = result.movedim(-1, 0) + >>> return result, 0 + >>> + >>> + >>> x = torch.randn(3) + >>> y = torch.randn(3) + >>> torch.vmap(numpy_mul)(x, y) + """ + from torch._functorch.autograd_function import custom_function_call_vmap_helper + from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter + + def register(func): + need_register = self._vmap_fn is None + self._vmap_fn = func + + if need_register: + + def wrapped_func(keyset, *args, **kwargs): + interpreter = retrieve_current_functorch_interpreter() + return custom_function_call_vmap_helper( + interpreter, self._vmap_fn, self._opoverload, *args, **kwargs + ) + + self._lib.impl( + self._name, wrapped_func, "FuncTorchBatched", with_keyset=True + ) + + if func is None: + return register + else: + return register(func) + + +# NOTE: [Supporting decorator and non-decorator usage] +# +# Some APIs may be both used as a decorator and not as a decorator. +# For example: +# +# >>> def fn(x): +# >>> return x.sin() +# >>> +# >>> # Usage 1: not as a decorator +# >>> numpy_sin.register_kernel("cuda", fn) +# >>> +# >>> # Usage 2: as a decorator +# >>> @numpy_sin.register_kernel("cuda") +# >>> def fn2(x): +# >>> return x.sin +# +# The way we support this is that `register_kernel` accepts an optional `fn`. +# If `fn` is provided (Usage 1), then we know that the user is using it not +# as a decorator. +# If `fn` is not provided (Usage 2), then `register_kernel` needs to return a +# decorator. 
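# Editor's note: illustrative sketch, not part of the vendored diff. It demonstrates
# the optional-`fn` pattern described in the NOTE above with a generic, hypothetical
# `register` helper: pass `fn` directly (usage 1), or omit it to get a decorator back
# (usage 2).
from typing import Callable, Dict, Optional

_REGISTRY: Dict[str, Callable] = {}


def register(name: str, fn: Optional[Callable] = None):
    def inner(func: Callable) -> Callable:
        _REGISTRY[name] = func
        return func

    if fn is None:        # usage 2: called without fn -> return a decorator
        return inner
    return inner(fn)      # usage 1: called with fn -> register immediately


def double(x):
    return 2 * x


register("double", double)      # non-decorator usage


@register("square")             # decorator usage
def square(x):
    return x * x


print(_REGISTRY["double"](3), _REGISTRY["square"](3))  # 6 9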
+ + +OPDEF_TO_LIB: Dict[str, "torch.library.Library"] = {} +OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + + +def get_library_allowing_overwrite( + namespace: str, name: str +) -> "torch.library.Library": + qualname = f"{namespace}::{name}" + + if qualname in OPDEF_TO_LIB: + OPDEF_TO_LIB[qualname]._destroy() + del OPDEF_TO_LIB[qualname] + + lib = torch.library.Library(namespace, "FRAGMENT") # noqa: TOR901 + OPDEF_TO_LIB[qualname] = lib + return lib + + +def iter_tensors( + args: Tuple[Any], kwargs: Dict[str, Any], allowed_nesting: int = 1 +) -> Iterator[Tensor]: + def check(arg): + if isinstance(arg, Tensor): + yield arg + elif allowed_nesting > 0 and isinstance(arg, (tuple, list)): + yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1) + + for arg in args: + yield from check(arg) + for kwarg in kwargs.values(): + yield from check(kwarg) + + +def _maybe_get_opdef( + op: Union[CustomOpDef, _ops.OpOverload, str] +) -> Optional[CustomOpDef]: + if isinstance(op, CustomOpDef): + return op + if isinstance(op, _ops.OpOverload): + op = op._name + assert isinstance(op, str) + if op in OPDEFS: + return OPDEFS[op] + return None
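# Editor's note: illustrative sketch, not part of the vendored diff. It mimics the
# one-level-of-nesting traversal performed by iter_tensors above so the effect of
# `allowed_nesting` is easy to see with eager tensors; the sketch function name is
# hypothetical and independent of the private torch._library module.
from typing import Any, Dict, Iterator, Tuple

import torch
from torch import Tensor


def iter_tensors_sketch(
    args: Tuple[Any, ...], kwargs: Dict[str, Any], allowed_nesting: int = 1
) -> Iterator[Tensor]:
    def check(arg):
        if isinstance(arg, Tensor):
            yield arg
        elif allowed_nesting > 0 and isinstance(arg, (tuple, list)):
            # Recurse one level down; deeper nesting is ignored once the budget is spent.
            yield from iter_tensors_sketch(tuple(arg), {}, allowed_nesting - 1)

    for arg in args:
        yield from check(arg)
    for kwarg in kwargs.values():
        yield from check(kwarg)


t1, t2, t3 = torch.zeros(1), torch.ones(1), torch.full((1,), 2.0)
found = list(iter_tensors_sketch((t1, [t2, [t3]]), {"scale": 0.5}))
print(len(found))  # 2 -- t3 sits two levels deep, beyond allowed_nesting=1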