Add files using upload-large-folder tool
This view is limited to 50 files because the commit contains too many changes.
- .venv/Lib/site-packages/torch/_inductor/__pycache__/codecache.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/_inductor/__pycache__/config.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/_inductor/__pycache__/exc.cpython-39.pyc +0 -0
- .venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py +296 -0
- .venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py +321 -0
- .venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py +149 -0
- .venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py +109 -0
- .venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py +0 -0
- .venv/Lib/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py +32 -0
- .venv/Lib/site-packages/torch/_inductor/codegen/codegen_device_driver.py +91 -0
- .venv/Lib/site-packages/torch/_inductor/codegen/common.py +2167 -0
- .venv/Lib/site-packages/torch/_inductor/codegen/cpp.py +0 -0
- .venv/Lib/site-packages/torch/_inductor/codegen/cpp_gemm_template.py +1043 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py +0 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py +173 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py +204 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py +203 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py +219 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py +129 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py +209 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py +227 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py +598 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py +243 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py +452 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py +208 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py +173 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py +189 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py +189 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py +177 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py +193 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py +220 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py +204 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py +220 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py +52 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py +44 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py +44 -0
- .venv/Lib/site-packages/torch/_inductor/fx_passes/split_cat.py +0 -0
- .venv/Lib/site-packages/torch/_inductor/kernel/__init__.py +1 -0
- .venv/Lib/site-packages/torch/_inductor/kernel/bmm.py +192 -0
- .venv/Lib/site-packages/torch/_inductor/kernel/conv.py +679 -0
- .venv/Lib/site-packages/torch/_inductor/kernel/flex_attention.py +1843 -0
- .venv/Lib/site-packages/torch/_inductor/kernel/flex_decoding.py +570 -0
- .venv/Lib/site-packages/torch/_inductor/kernel/mm.py +776 -0
- .venv/Lib/site-packages/torch/_inductor/kernel/mm_common.py +466 -0
- .venv/Lib/site-packages/torch/_inductor/kernel/mm_plus_mm.py +248 -0
- .venv/Lib/site-packages/torch/_inductor/kernel/mm_scaled.py +311 -0
- .venv/Lib/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py +87 -0
.venv/Lib/site-packages/torch/_inductor/__pycache__/codecache.cpython-39.pyc
ADDED
Binary file (92.1 kB)

.venv/Lib/site-packages/torch/_inductor/__pycache__/config.cpython-39.pyc
ADDED
Binary file (17.3 kB)

.venv/Lib/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-39.pyc
ADDED
Binary file (35.9 kB)

.venv/Lib/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-39.pyc
ADDED
Binary file (9.92 kB)

.venv/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-39.pyc
ADDED
Binary file (10.5 kB)

.venv/Lib/site-packages/torch/_inductor/__pycache__/exc.cpython-39.pyc
ADDED
Binary file (4.59 kB)
.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py
ADDED
@@ -0,0 +1,296 @@
+# flake8: noqa: B950
+# fmt: off
+# This file was generated by AutoHeuristic. Do not modify it manually!
+# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/
+from typing import List, Optional, Tuple
+
+from torch._inductor.autoheuristic.autoheuristic_utils import (
+    AHContext,
+    AHMetadata,
+    Choice,
+)
+from torch._inductor.autoheuristic.learnedheuristic_interface import (
+    LearnedHeuristicDecision,
+)
+
+
+class MMRankingA100(LearnedHeuristicDecision):
+
+    def __init__(self) -> None:
+        self.choices: List[Choice] = []
+        self.fill_choices()
+
+    def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
+        return (
+            metadata.name == self.get_name()
+            and metadata.shared_memory == 166912
+            and str(metadata.device_capa) == "(8, 0)"
+        )
+
+    def get_confidence_threshold(self) -> float:
+        return 0.0
+
+    def get_choice(self, idx: int) -> Optional[str]:
+        if idx < len(self.choices):
+            return self.choices[idx]
+        return None
+
+    def fill_choices(self) -> None:
+        self.choices.append('extern_mm')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
+
+    def get_name(self) -> str:
+        return 'mm'
+
+    def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
+        if context.get_value('arith_intensity') <= 52.6245059967041:
+            if context.get_value('n') <= 34.0:
+                if context.get_value('n') <= 18.0:
+                    if context.get_value('k*n') <= 312.0:
+                        return [(0.093, 12), (0.081, 16), (0.081, 148), (0.070, 10), (0.070, 17), (0.070, 149), (0.070, 151), (0.070, 150), (0.070, 14), (0.058, 11), (0.058, 15), (0.058, 13), (0.058, 122), (0.047, 121), (0.035, 123), (0.012, 92)]
+                    else:
+                        if context.get_value('k') <= 40.0:
+                            return [(0.083, 42), (0.083, 46), (0.083, 44), (0.083, 40), (0.083, 128), (0.067, 45), (0.067, 43), (0.067, 41), (0.067, 169), (0.067, 171), (0.067, 168), (0.067, 129), (0.067, 170), (0.033, 103), (0.017, 121)]
+                        else:
+                            return [(0.112, 137), (0.104, 136), (0.101, 0), (0.081, 1), (0.073, 135), (0.069, 67), (0.066, 187), (0.058, 41), (0.050, 71), (0.046, 68), (0.046, 70), (0.031, 44), (0.027, 43), (0.027, 170), (0.019, 189), (0.019, 188), (0.015, 169), (0.015, 171), (0.012, 115), (0.012, 168), (0.012, 69), (0.004, 103)]
+                else:
+                    if context.get_value('mat1_stride_0') <= 20.0:
+                        return [(0.069, 0), (0.059, 157), (0.059, 22), (0.059, 153), (0.059, 155), (0.059, 25), (0.059, 23), (0.059, 19), (0.044, 21), (0.044, 18), (0.044, 152), (0.044, 158), (0.044, 154), (0.044, 156), (0.044, 20), (0.044, 124), (0.044, 24), (0.030, 125), (0.029, 126), (0.015, 97), (0.015, 95), (0.015, 96), (0.010, 2), (0.010, 75)]
+                    else:
+                        if context.get_value('k') <= 68.0:
+                            return [(0.087, 72), (0.087, 74), (0.087, 73), (0.086, 76), (0.077, 75), (0.067, 192), (0.058, 190), (0.048, 47), (0.048, 193), (0.048, 49), (0.048, 51), (0.048, 191), (0.038, 53), (0.019, 133), (0.019, 50), (0.019, 175), (0.019, 172), (0.019, 48), (0.019, 174), (0.010, 173), (0.010, 177), (0.010, 52), (0.010, 54), (0.010, 178), (0.010, 176)]
+                        else:
+                            return [(0.154, 52), (0.154, 72), (0.102, 75), (0.087, 49), (0.087, 73), (0.086, 51), (0.057, 176), (0.045, 2), (0.038, 191), (0.038, 178), (0.038, 190), (0.029, 173), (0.029, 76), (0.026, 138), (0.013, 139), (0.013, 140), (0.003, 0)]
+            else:
+                if context.get_value('k') <= 35.0:
+                    if context.get_value('k') <= 18.0:
+                        if context.get_value('m*n') <= 19505152.0:
+                            return [(0.151, 159), (0.140, 160), (0.129, 164), (0.055, 127), (0.051, 29), (0.044, 161), (0.044, 147), (0.040, 146), (0.040, 31), (0.037, 145), (0.026, 28), (0.022, 90), (0.022, 93), (0.022, 94), (0.022, 100), (0.022, 125), (0.022, 158), (0.022, 157), (0.011, 87), (0.011, 88), (0.011, 89), (0.011, 91), (0.011, 95), (0.011, 96), (0.011, 98), (0.011, 99)]
+                        else:
+                            return [(0.069, 7), (0.069, 5), (0.067, 147), (0.066, 8), (0.061, 145), (0.058, 146), (0.052, 124), (0.049, 29), (0.049, 159), (0.046, 31), (0.043, 157), (0.041, 9), (0.041, 4), (0.040, 6), (0.035, 164), (0.035, 160), (0.026, 158), (0.017, 125), (0.017, 28), (0.017, 32), (0.017, 162), (0.017, 27), (0.017, 30), (0.017, 161), (0.009, 33), (0.009, 26), (0.009, 163), (0.006, 0)]
+                    else:
+                        if context.get_value('n') <= 68.0:
+                            return [(0.101, 182), (0.101, 59), (0.088, 57), (0.076, 184), (0.076, 61), (0.076, 179), (0.076, 62), (0.076, 58), (0.063, 180), (0.063, 60), (0.051, 56), (0.050, 181), (0.025, 130), (0.025, 177), (0.025, 183), (0.013, 178), (0.013, 55)]
+                        else:
+                            return [(0.089, 180), (0.079, 60), (0.066, 35), (0.066, 181), (0.066, 38), (0.066, 58), (0.066, 179), (0.066, 57), (0.062, 184), (0.053, 37), (0.044, 166), (0.040, 55), (0.040, 39), (0.040, 36), (0.040, 165), (0.040, 167), (0.027, 177), (0.027, 34), (0.022, 159)]
+                else:
+                    if context.get_value('m*n') <= 309760.0:
+                        return [(0.298, 0), (0.097, 140), (0.080, 83), (0.072, 86), (0.044, 84), (0.036, 178), (0.036, 117), (0.036, 82), (0.032, 120), (0.032, 85), (0.028, 119), (0.024, 130), (0.024, 109), (0.020, 108), (0.020, 118), (0.012, 104), (0.012, 116), (0.012, 141), (0.012, 144), (0.008, 105), (0.008, 106), (0.008, 111), (0.008, 114), (0.008, 107), (0.008, 132), (0.004, 101), (0.004, 102), (0.004, 110), (0.004, 112), (0.004, 113), (0.004, 131)]
+                    else:
+                        if context.get_value('n') <= 72.0:
+                            return [(0.227, 77), (0.118, 78), (0.102, 194), (0.086, 80), (0.059, 57), (0.054, 81), (0.049, 196), (0.048, 197), (0.048, 59), (0.043, 79), (0.032, 195), (0.027, 180), (0.022, 3), (0.021, 141), (0.016, 60), (0.016, 142), (0.011, 183), (0.011, 0), (0.011, 144)]
+                        else:
+                            return [(0.140, 186), (0.132, 185), (0.109, 63), (0.085, 65), (0.078, 37), (0.077, 35), (0.062, 197), (0.047, 194), (0.046, 165), (0.046, 57), (0.039, 78), (0.039, 79), (0.039, 66), (0.039, 64), (0.016, 195), (0.008, 159)]
+        else:
+            if str(context.get_value('using_tf32')) != 'False':
+                if context.get_value('m*n') <= 815360.0:
+                    if context.get_value('k') <= 1184.0:
+                        return [(0.218, 140), (0.205, 0), (0.154, 144), (0.115, 141), (0.051, 185), (0.051, 104), (0.039, 78), (0.038, 116), (0.026, 165), (0.026, 130), (0.026, 178), (0.013, 57), (0.013, 195), (0.013, 167), (0.013, 186)]
+                    else:
+                        return [(0.901, 0), (0.030, 144), (0.030, 134), (0.016, 3), (0.006, 78), (0.006, 77), (0.002, 57), (0.002, 194), (0.002, 59), (0.002, 60), (0.002, 143)]
+                else:
+                    if context.get_value('arith_intensity') <= 187.23922729492188:
+                        if context.get_value('mat1_stride_0') <= 198.0:
+                            return [(0.273, 63), (0.158, 37), (0.152, 35), (0.127, 57), (0.097, 165), (0.053, 185), (0.031, 0), (0.028, 64), (0.014, 60), (0.014, 78), (0.009, 55), (0.008, 134), (0.005, 34), (0.005, 167), (0.005, 179), (0.005, 65), (0.005, 66), (0.005, 186), (0.005, 194), (0.002, 166)]
+                        else:
+                            return [(0.296, 63), (0.235, 0), (0.132, 64), (0.074, 37), (0.069, 78), (0.051, 185), (0.051, 35), (0.030, 57), (0.020, 77), (0.016, 194), (0.008, 66), (0.007, 65), (0.003, 3), (0.003, 165), (0.003, 141), (0.001, 134), (0.001, 166)]
+                    else:
+                        return [(0.405, 0), (0.246, 37), (0.177, 63), (0.145, 35), (0.005, 185), (0.005, 65), (0.005, 64), (0.004, 57), (0.003, 66), (0.002, 165), (0.001, 78), (0.001, 55)]
+            else:
+                return [(0.357, 0), (0.112, 165), (0.101, 57), (0.094, 179), (0.086, 64), (0.074, 167), (0.067, 60), (0.064, 159), (0.033, 35), (0.007, 195), (0.002, 180), (0.001, 34), (0.001, 166), (0.001, 78)]
.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py
ADDED
@@ -0,0 +1,321 @@
+# flake8: noqa: B950
+# fmt: off
+# This file was generated by AutoHeuristic. Do not modify it manually!
+# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/
+from typing import List, Optional, Tuple
+
+from torch._inductor.autoheuristic.autoheuristic_utils import (
+    AHContext,
+    AHMetadata,
+    Choice,
+)
+from torch._inductor.autoheuristic.learnedheuristic_interface import (
+    LearnedHeuristicDecision,
+)
+
+
+class MMRankingH100(LearnedHeuristicDecision):
+
+    def __init__(self) -> None:
+        self.choices: List[Choice] = []
+        self.fill_choices()
+
+    def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
+        return (
+            metadata.name == self.get_name()
+            and metadata.shared_memory == 232448
+            and str(metadata.device_capa) == "(9, 0)"
+        )
+
+    def get_confidence_threshold(self) -> float:
+        return 0.0
+
+    def get_choice(self, idx: int) -> Optional[str]:
+        if idx < len(self.choices):
+            return self.choices[idx]
+        return None
+
+    def fill_choices(self) -> None:
+        self.choices.append('extern_mm')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=1')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=1')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=1')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=16_numstages=2_numwarps=2')
+        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
+        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
+
+    def get_name(self) -> str:
+        return 'mm'
+
+    def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
+        if context.get_value('arith_intensity') <= 29.89772129058838:
+            if context.get_value('n') <= 34.0:
+                if context.get_value('n') <= 18.0:
+                    if context.get_value('k*n') <= 432.0:
+                        if context.get_value('arith_intensity') <= 7.8700292110443115:
+                            return [(0.098, 128), (0.098, 129), (0.098, 127), (0.073, 14), (0.073, 16), (0.073, 12), (0.073, 154), (0.073, 156), (0.073, 157), (0.073, 155), (0.049, 10), (0.049, 94), (0.049, 95), (0.048, 96)]
|
| 252 |
+
else:
|
| 253 |
+
return [(0.091, 154), (0.073, 10), (0.073, 15), (0.073, 13), (0.073, 11), (0.073, 17), (0.073, 16), (0.073, 14), (0.073, 12), (0.055, 127), (0.054, 157), (0.054, 156), (0.054, 155), (0.036, 129), (0.036, 128), (0.018, 41), (0.018, 43)]
|
| 254 |
+
else:
|
| 255 |
+
if context.get_value('k') <= 40.0:
|
| 256 |
+
return [(0.070, 39), (0.069, 45), (0.069, 41), (0.069, 43), (0.069, 111), (0.069, 112), (0.056, 38), (0.056, 40), (0.056, 42), (0.056, 44), (0.056, 174), (0.056, 173), (0.056, 175), (0.056, 134), (0.056, 172), (0.056, 135), (0.014, 154), (0.014, 127)]
|
| 257 |
+
else:
|
| 258 |
+
return [(0.147, 144), (0.119, 143), (0.087, 142), (0.083, 0), (0.073, 191), (0.059, 69), (0.050, 67), (0.046, 70), (0.041, 1), (0.036, 174), (0.032, 43), (0.032, 123), (0.028, 40), (0.027, 42), (0.027, 173), (0.023, 175), (0.018, 66), (0.014, 192), (0.014, 193), (0.014, 139), (0.014, 68), (0.014, 127)]
|
| 259 |
+
else:
|
| 260 |
+
if context.get_value('mat1_stride_0') <= 40.0:
|
| 261 |
+
if context.get_value('mat1_stride_0') <= 20.0:
|
| 262 |
+
return [(0.109, 23), (0.109, 21), (0.109, 20), (0.088, 0), (0.087, 131), (0.066, 18), (0.065, 130), (0.065, 132), (0.065, 159), (0.065, 160), (0.065, 161), (0.065, 158), (0.022, 22), (0.022, 19)]
|
| 263 |
+
else:
|
| 264 |
+
return [(0.065, 46), (0.064, 52), (0.064, 50), (0.064, 48), (0.064, 51), (0.064, 49), (0.064, 47), (0.064, 53), (0.064, 181), (0.064, 177), (0.064, 179), (0.064, 176), (0.038, 130), (0.038, 136), (0.026, 182), (0.026, 178), (0.026, 180), (0.026, 137), (0.025, 158), (0.013, 114), (0.013, 113)]
|
| 265 |
+
else:
|
| 266 |
+
if context.get_value('mat1_stride_0') <= 68.0:
|
| 267 |
+
return [(0.138, 140), (0.125, 195), (0.100, 71), (0.100, 74), (0.100, 196), (0.100, 194), (0.100, 197), (0.075, 75), (0.062, 72), (0.062, 73), (0.012, 180), (0.012, 51), (0.012, 182)]
|
| 268 |
+
else:
|
| 269 |
+
return [(0.124, 180), (0.124, 182), (0.114, 75), (0.103, 74), (0.093, 51), (0.093, 71), (0.072, 72), (0.062, 194), (0.052, 145), (0.052, 195), (0.021, 48), (0.021, 50), (0.021, 47), (0.020, 124), (0.010, 147), (0.010, 146), (0.010, 46)]
|
| 270 |
+
else:
|
| 271 |
+
if context.get_value('k') <= 18.0:
|
| 272 |
+
if context.get_value('m*k') <= 528.0:
|
| 273 |
+
return [(0.097, 88), (0.087, 92), (0.077, 90), (0.058, 105), (0.058, 103), (0.058, 104), (0.058, 99), (0.058, 100), (0.058, 106), (0.058, 93), (0.057, 91), (0.057, 97), (0.057, 98), (0.057, 101), (0.048, 102), (0.029, 87), (0.029, 89)]
|
| 274 |
+
else:
|
| 275 |
+
if context.get_value('n') <= 80.0:
|
| 276 |
+
return [(0.057, 161), (0.057, 130), (0.057, 24), (0.056, 164), (0.056, 163), (0.056, 166), (0.056, 168), (0.056, 30), (0.056, 28), (0.056, 26), (0.056, 25), (0.056, 27), (0.056, 29), (0.056, 31), (0.042, 131), (0.028, 99), (0.028, 101), (0.028, 100), (0.028, 167), (0.028, 165), (0.028, 133)]
|
| 277 |
+
else:
|
| 278 |
+
return [(0.110, 164), (0.108, 163), (0.106, 168), (0.069, 161), (0.066, 151), (0.060, 152), (0.055, 165), (0.050, 27), (0.050, 29), (0.048, 131), (0.043, 153), (0.037, 133), (0.037, 130), (0.028, 8), (0.028, 5), (0.027, 7), (0.026, 26), (0.016, 162), (0.012, 9), (0.007, 4), (0.005, 100), (0.005, 6), (0.005, 24)]
|
| 279 |
+
else:
|
| 280 |
+
if context.get_value('k') <= 36.0:
|
| 281 |
+
if context.get_value('n') <= 68.0:
|
| 282 |
+
return [(0.097, 184), (0.097, 56), (0.086, 186), (0.086, 183), (0.086, 188), (0.086, 58), (0.086, 60), (0.065, 54), (0.043, 187), (0.043, 185), (0.043, 57), (0.043, 61), (0.032, 55), (0.032, 130), (0.032, 59), (0.011, 181), (0.011, 163), (0.011, 136), (0.011, 138)]
|
| 283 |
+
else:
|
| 284 |
+
return [(0.117, 184), (0.117, 170), (0.117, 169), (0.107, 183), (0.106, 188), (0.075, 181), (0.064, 130), (0.064, 56), (0.053, 171), (0.032, 57), (0.032, 59), (0.032, 185), (0.011, 163), (0.011, 32), (0.011, 37), (0.011, 34), (0.011, 33), (0.011, 35), (0.011, 36), (0.011, 54)]
|
| 285 |
+
else:
|
| 286 |
+
if context.get_value('mat2_stride_0') <= 384.0:
|
| 287 |
+
return [(0.244, 0), (0.061, 76), (0.061, 79), (0.030, 3), (0.030, 183), (0.030, 189), (0.030, 187), (0.030, 64), (0.030, 190), (0.030, 62), (0.030, 198), (0.030, 201), (0.030, 77), (0.030, 200), (0.030, 80), (0.030, 199), (0.030, 78), (0.030, 184), (0.020, 86), (0.020, 84), (0.020, 120), (0.020, 81), (0.020, 121), (0.020, 85), (0.020, 122), (0.010, 83), (0.010, 118), (0.010, 119), (0.010, 82)]
|
| 288 |
+
else:
|
| 289 |
+
return [(0.274, 83), (0.171, 86), (0.152, 0), (0.071, 85), (0.061, 125), (0.050, 84), (0.020, 109), (0.020, 117), (0.020, 81), (0.020, 118), (0.020, 121), (0.020, 108), (0.020, 115), (0.020, 116), (0.010, 110), (0.010, 120), (0.010, 103), (0.010, 107), (0.010, 119), (0.010, 122)]
|
| 290 |
+
else:
|
| 291 |
+
if context.get_value('arith_intensity') <= 56.995582580566406:
|
| 292 |
+
if context.get_value('n') <= 68.0:
|
| 293 |
+
if context.get_value('k*n') <= 4448.0:
|
| 294 |
+
if context.get_value('m*n') <= 29626368.0:
|
| 295 |
+
return [(0.107, 198), (0.107, 200), (0.107, 201), (0.107, 199), (0.106, 76), (0.106, 79), (0.064, 197), (0.063, 56), (0.043, 184), (0.043, 187), (0.042, 80), (0.042, 77), (0.042, 183), (0.021, 78)]
|
| 296 |
+
else:
|
| 297 |
+
return [(0.073, 201), (0.073, 198), (0.073, 200), (0.073, 199), (0.073, 197), (0.073, 56), (0.073, 58), (0.073, 79), (0.073, 76), (0.072, 59), (0.072, 78), (0.072, 77), (0.072, 80), (0.018, 184), (0.018, 55), (0.018, 54)]
|
| 298 |
+
else:
|
| 299 |
+
if context.get_value('k') <= 348.0:
|
| 300 |
+
return [(0.206, 76), (0.183, 77), (0.169, 198), (0.160, 199), (0.053, 59), (0.046, 56), (0.038, 3), (0.030, 148), (0.030, 58), (0.030, 187), (0.023, 184), (0.015, 0), (0.008, 55), (0.008, 54)]
|
| 301 |
+
else:
|
| 302 |
+
return [(0.146, 198), (0.145, 199), (0.145, 148), (0.126, 0), (0.084, 76), (0.084, 77), (0.042, 80), (0.042, 79), (0.021, 149), (0.021, 150), (0.021, 3), (0.014, 46), (0.014, 74), (0.014, 75), (0.014, 124), (0.014, 194), (0.014, 195), (0.007, 145), (0.007, 146), (0.007, 2), (0.007, 72), (0.007, 147), (0.007, 71)]
|
| 303 |
+
else:
|
| 304 |
+
if context.get_value('m') <= 3264.0:
|
| 305 |
+
return [(0.247, 147), (0.115, 197), (0.066, 199), (0.066, 201), (0.066, 198), (0.049, 0), (0.049, 169), (0.049, 171), (0.033, 140), (0.033, 125), (0.033, 114), (0.016, 126), (0.016, 183), (0.016, 184), (0.016, 185), (0.016, 182), (0.016, 188), (0.016, 78), (0.016, 148), (0.016, 138), (0.016, 77), (0.016, 56), (0.016, 59)]
|
| 306 |
+
else:
|
| 307 |
+
if context.get_value('k') <= 62.5:
|
| 308 |
+
return [(0.226, 190), (0.226, 189), (0.122, 62), (0.122, 64), (0.055, 77), (0.055, 78), (0.037, 198), (0.036, 201), (0.036, 33), (0.024, 163), (0.018, 56), (0.018, 35), (0.018, 169), (0.006, 171)]
|
| 309 |
+
else:
|
| 310 |
+
return [(0.162, 35), (0.118, 33), (0.096, 189), (0.096, 190), (0.088, 169), (0.074, 62), (0.073, 56), (0.066, 171), (0.051, 198), (0.051, 201), (0.044, 59), (0.037, 64), (0.029, 63), (0.007, 0), (0.007, 77)]
|
| 311 |
+
else:
|
| 312 |
+
if context.get_value('m*n') <= 1097728.0:
|
| 313 |
+
return [(0.403, 0), (0.179, 141), (0.134, 150), (0.086, 147), (0.051, 148), (0.048, 3), (0.024, 189), (0.020, 199), (0.017, 64), (0.010, 65), (0.010, 77), (0.007, 114), (0.003, 138), (0.003, 59), (0.003, 182)]
|
| 314 |
+
else:
|
| 315 |
+
if context.get_value('m*n') <= 3244032.0:
|
| 316 |
+
return [(0.295, 189), (0.176, 64), (0.157, 65), (0.090, 0), (0.069, 62), (0.059, 63), (0.046, 77), (0.039, 169), (0.023, 199), (0.020, 35), (0.013, 33), (0.010, 171), (0.003, 141)]
|
| 317 |
+
else:
|
| 318 |
+
if context.get_value('n') <= 136.0:
|
| 319 |
+
return [(0.197, 189), (0.197, 63), (0.161, 77), (0.157, 62), (0.061, 33), (0.044, 65), (0.039, 35), (0.039, 64), (0.030, 169), (0.026, 0), (0.017, 199), (0.017, 148), (0.009, 56), (0.004, 3)]
|
| 320 |
+
else:
|
| 321 |
+
return [(0.460, 0), (0.145, 62), (0.138, 63), (0.081, 35), (0.047, 33), (0.043, 189), (0.023, 64), (0.018, 77), (0.013, 169), (0.009, 65), (0.009, 56), (0.005, 32), (0.005, 59), (0.002, 183), (0.002, 163)]
|
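For orientation, a minimal sketch of how a ranking heuristic like the one above could be consumed. It assumes MMRankingH100 exposes the same get_choice(idx) accessor shown for MixedMMH100 below; the helper name top_ranked_choices is made up for illustration.

# Hypothetical usage sketch -- not part of the diff above.
heuristic = MMRankingH100()

def top_ranked_choices(context, k=3):
    # get_best_choices returns (score, choice_index) pairs; map the
    # indices back to Triton config strings via get_choice.
    best = heuristic.get_best_choices(context)
    if best is None:
        return None
    return [heuristic.get_choice(idx) for _, idx in sorted(best, reverse=True)[:k]]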
.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py ADDED
@@ -0,0 +1,149 @@
# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
from typing import List, Optional, Tuple

from torch._inductor.autoheuristic.autoheuristic_utils import (
    AHContext,
    AHMetadata,
    Choice,
)
from torch._inductor.autoheuristic.learnedheuristic_interface import (
    LearnedHeuristicDecision,
)


class MixedMMH100(LearnedHeuristicDecision):

    def __init__(self) -> None:
        self.choices: List[Choice] = []
        self.fill_choices()

    def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
        return (
            metadata.name == self.get_name()
            and metadata.shared_memory == 232448
            and str(metadata.device_capa) == "(9, 0)"
        )

    def get_confidence_threshold(self) -> float:
        return 0.0

    def get_choice(self, idx: int) -> Optional[str]:
        if idx < len(self.choices):
            return self.choices[idx]
        return None

    def fill_choices(self) -> None:
        self.choices.append('extern_fallback_mixed_mm')
        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
        self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
        self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')

    def get_name(self) -> str:
        return 'mixed_mm'

    def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
        if context.get_value('arith_intensity') <= 15.988086223602295:
            if context.get_value('n') <= 25280.0:
                if context.get_value('n') <= 1344.0:
                    if context.get_value('mat1_stride_0') <= 7808.0:
                        return [(0.581, 7), (0.419, 6)]
                    else:
                        if context.get_value('m*n') <= 7680.0:
                            return [(0.875, 0), (0.125, 6)]
                        else:
                            return [(0.833, 0), (0.167, 7)]
                else:
                    if context.get_value('n') <= 8512.0:
                        if str(context.get_value('mat2_dtype')) != 'torch.int8':
                            return [(0.763, 6), (0.237, 7)]
                        else:
                            return [(0.725, 7), (0.275, 6)]
                    else:
                        if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
                            return [(0.736, 7), (0.197, 9), (0.048, 6), (0.014, 8), (0.005, 10)]
                        else:
                            return [(0.473, 7), (0.398, 6), (0.097, 9), (0.032, 10)]
            else:
                if context.get_value('n') <= 42254.0:
                    if context.get_value('n') <= 33856.0:
                        if context.get_value('k*n') <= 68157440.0:
                            return [(0.370, 4), (0.370, 5), (0.074, 7), (0.074, 8), (0.074, 11), (0.037, 6)]
                        else:
                            return [(0.916, 8), (0.036, 7), (0.036, 9), (0.012, 4)]
                    else:
                        return [(0.659, 5), (0.341, 6)]
                else:
                    if context.get_value('k*n') <= 326052992.0:
                        if context.get_value('n') <= 55232.0:
                            return [(0.571, 6), (0.321, 7), (0.036, 4), (0.036, 8), (0.036, 9)]
                        else:
                            return [(0.506, 6), (0.325, 8), (0.104, 7), (0.039, 5), (0.026, 9)]
                    else:
                        if context.get_value('n') <= 57024.0:
                            return [(0.462, 9), (0.385, 7), (0.115, 6), (0.038, 8)]
                        else:
                            return [(0.598, 8), (0.223, 9), (0.107, 6), (0.071, 7)]
        else:
            if context.get_value('m*n') <= 543936.0:
                if str(context.get_value('17LEQmLEQ32')) != 'True':
                    if context.get_value('m*n') <= 262272.0:
                        if context.get_value('n') <= 1592.5:
                            return [(0.860, 0), (0.140, 9)]
                        else:
                            return None
                    else:
                        if context.get_value('m*k') <= 1294336.0:
                            return [(0.833, 17), (0.150, 18), (0.017, 15)]
                        else:
                            return [(0.917, 17), (0.083, 8)]
                else:
                    if context.get_value('n') <= 12416.0:
                        if context.get_value('m*n') <= 43008.0:
                            return None
                        else:
                            return [(0.853, 14), (0.147, 9)]
                    else:
                        return [(0.625, 12), (0.375, 14)]
            else:
                if context.get_value('m') <= 32.5:
                    if context.get_value('mat2_stride_1') <= 6656.0:
                        if context.get_value('n') <= 69184.0:
                            return [(0.611, 12), (0.361, 14), (0.028, 13)]
                        else:
                            return [(1.000, 12)]
                    else:
                        if context.get_value('mat2_stride_1') <= 20864.0:
                            return [(1.000, 12)]
                        else:
                            return [(0.958, 12), (0.042, 9)]
                else:
                    if context.get_value('m*n') <= 1085440.0:
                        if context.get_value('n') <= 9152.0:
                            return [(1.000, 18)]
                        else:
                            return [(0.780, 18), (0.160, 16), (0.060, 20)]
                    else:
                        if context.get_value('m') <= 67.0:
                            return [(0.650, 16), (0.203, 19), (0.122, 18), (0.016, 20), (0.008, 1)]
                        else:
                            return [(0.561, 3), (0.185, 16), (0.096, 20), (0.083, 19), (0.076, 2)]
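A minimal sketch of the precondition gate above: the learned tree is only trusted on the exact hardware it was trained on (here, an H100 with 232448 bytes of shared memory and compute capability (9, 0)). The helper choices_for is hypothetical.

# Hypothetical usage sketch -- not part of the diff above.
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata

def choices_for(metadata: AHMetadata, context: AHContext):
    heuristic = MixedMMH100()
    if not heuristic.check_precondition(metadata, context):
        return None  # on other devices, fall back to regular autotuning
    return heuristic.get_best_choices(context)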
.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py ADDED
@@ -0,0 +1,109 @@
# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/
from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL
from torch._inductor.autoheuristic.learnedheuristic_interface import (
    LearnedHeuristicRegression,
)


class PadMMA100(LearnedHeuristicRegression):

    def __init__(self) -> None:
        pass

    def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
        return (
            metadata.name == self.get_name()
            and metadata.shared_memory == 166912
            and str(metadata.device_capa) == "(8, 0)"
        )

    def get_feedback(self, context: AHContext, choice: Choice) -> float:
        context.context_dict[CHOICE_COL] = choice
        return self.predict(context)

    def get_confidence_threshold(self) -> float:
        return 1.7025303314066

    def get_name(self) -> str:
        return 'pad_mm'

    def predict(self, context: AHContext) -> float:
        if str(context.get_value('choice')) != 'pad':
            if str(context.get_value('using_tf32')) != 'False':
                if context.get_value('m*n') <= 4171264.0:
                    if context.get_value('m*k') <= 3999308.0:
                        return 1.8751469764071178
                    else:
                        if str(context.get_value('n_multiple_32')) != 'True':
                            return 0.9117231355626345
                        else:
                            return 1.1607689608873861
                else:
                    if str(context.get_value('n_multiple_2')) != 'True':
                        if str(context.get_value('using_tf32')) != 'True':
                            return 0.7430382200435992
                        else:
                            return 0.8531269794448678
                    else:
                        if str(context.get_value('k_multiple_2')) != 'True':
                            return 0.7577181972719917
                        else:
                            return 0.8977349440424219
            else:
                if context.get_value('m*n') <= 1299712.0:
                    return 1.1669723418995592
                else:
                    if context.get_value('mat2_stride_1') <= 45217.5:
                        if context.get_value('m*n') <= 55884158.0:
                            return 1.0262769936909601
                        else:
                            return 1.0022677428470845
                    else:
                        if context.get_value('m') <= 18478.0:
                            return 1.1127066261894312
                        else:
                            return 1.0337740659894263
        else:
            if str(context.get_value('mat1_dtype')) != 'torch.float32':
                if str(context.get_value('n_multiple_2')) != 'False':
                    if str(context.get_value('k_multiple_2')) != 'True':
                        if context.get_value('mat1_stride_0') <= 561.0:
                            return 1.2900382135142956
                        else:
                            return 1.5761737616057887
                    else:
                        if context.get_value('num_dims_needs_padding') <= 1.5:
                            return 1.0472263310239422
                        else:
                            return 1.1727673465762514
                else:
                    if context.get_value('k') <= 28238.5:
                        if context.get_value('k/(m*n)') <= 0.00026227018679492176:
                            return 1.6770542505397175
                        else:
                            return 1.3974785435105923
                    else:
                        if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
                            return 1.3952699800111992
                        else:
                            return 1.5759286511628336
            else:
                if str(context.get_value('using_tf32')) != 'False':
                    if context.get_value('m*n') <= 14119424.0:
                        return 0.8875772670422478
                    else:
                        if str(context.get_value('mat2_innermost_needs_padding')) != 'True':
                            return 1.1467728924377265
                        else:
                            return 1.215842963532998
                else:
                    if context.get_value('arith_intensity') <= 396.8774871826172:
                        return 0.89940161869551
                    else:
                        if context.get_value('mat2_stride_1') <= 45217.5:
                            return 0.9964328169353532
                        else:
                            return 0.9493479238294826
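Unlike the two decision heuristics above, PadMMA100 is a regression heuristic: get_feedback writes the candidate choice into the context and runs the tree, returning a predicted speedup figure. A minimal sketch of querying it; the comparison of a single prediction against the confidence threshold is a simplified, assumed decision rule (the real caller compares feedback across choices), and should_pad is a made-up helper.

# Hypothetical usage sketch -- not part of the diff above.
heuristic = PadMMA100()

def should_pad(context) -> bool:
    # Predicted relative benefit of the 'pad' choice for this context.
    predicted = heuristic.get_feedback(context, "pad")
    # Simplified rule: only act when the prediction clears the learned bar.
    return predicted > heuristic.get_confidence_threshold()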
.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py ADDED
File without changes
.venv/Lib/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py ADDED
@@ -0,0 +1,32 @@
# mypy: allow-untyped-defs
import re

import torch
from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE


# It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like:
# "...
#     from ..codecache import CudaKernelParamCache
# ..."
# In such cases, we do not need to hipify_torch the original class/file name in codegen/codecache


def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str:
    if torch.version.hip is None and not force_hipify:
        return source_codes

    def c2_repl(m):
        return PYTORCH_MAP[m.group(0)]

    # We need to redefine RE_PYTORCH_PREPROCESSOR here since in hipify_torch,
    # it applies a positive lookbehind (?<=\W) to the pattern to avoid matching
    # a keyword at the beginning of a code line. However, that can happen in
    # codegen, which would cause the pattern to not match.

    # Note that the lookahead (?=\W) is still needed to keep hipification idempotent; for example
    # we need to skip replacing "getStreamFromExternal" in "getStreamFromExternalMasqueradingAsCUDA"
    RE_PYTORCH_PREPROCESSOR = re.compile(rf"({PYTORCH_TRIE.export_to_regex()})(?=\W)")

    source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes)
    return source_codes
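A minimal sketch of exercising the helper above. The input string and the exact rewritten output depend on PYTORCH_MAP's mappings, so the result shown in the comment is an assumption rather than a guarantee.

# Hypothetical usage sketch -- not part of the diff above.
from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper

cuda_src = "cudaStream_t stream;"
# On a ROCm build (torch.version.hip set) CUDA identifiers are rewritten to
# their HIP equivalents via PYTORCH_MAP; on a CUDA build this is a no-op
# unless force_hipify=True. Expected (assumed): "hipStream_t stream;".
print(maybe_hipify_code_wrapper(cuda_src, force_hipify=True))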
.venv/Lib/site-packages/torch/_inductor/codegen/codegen_device_driver.py ADDED
@@ -0,0 +1,91 @@
import torch


# Provide aoti module launch hip/cuda drivers. This file is also used for unit testing purposes.


def cuda_kernel_driver() -> str:
    source_codes = """
    #define CUDA_DRIVER_CHECK(EXPR)                    \\
    do {                                               \\
        CUresult code = EXPR;                          \\
        const char *msg;                               \\
        cuGetErrorString(code, &msg);                  \\
        if (code != CUDA_SUCCESS) {                    \\
            throw std::runtime_error(                  \\
                std::string("CUDA driver error: ") +   \\
                std::string(msg));                     \\
        }                                              \\
    } while (0);

    namespace {

    struct Grid {
        Grid(uint32_t x, uint32_t y, uint32_t z)
          : grid_x(x), grid_y(y), grid_z(z) {}
        uint32_t grid_x;
        uint32_t grid_y;
        uint32_t grid_z;

        bool is_non_zero() {
            return grid_x > 0 && grid_y > 0 && grid_z > 0;
        }
    };

    }  // anonymous namespace

    static inline CUfunction loadKernel(
            std::string filePath,
            const std::string &funcName,
            uint32_t sharedMemBytes,
            const std::optional<std::string> &cubinDir = std::nullopt) {
        if (cubinDir) {
            std::filesystem::path p1{*cubinDir};
            std::filesystem::path p2{filePath};
            filePath = (p1 / p2.filename()).string();
        }

        CUmodule mod;
        CUfunction func;
        CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
        CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
        if (sharedMemBytes > 0) {
            CUDA_DRIVER_CHECK(cuFuncSetAttribute(
                func,
                CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                sharedMemBytes
            ))
        }
        return func;
    }

    static inline void launchKernel(
            CUfunction func,
            uint32_t gridX,
            uint32_t gridY,
            uint32_t gridZ,
            uint32_t numWarps,
            uint32_t sharedMemBytes,
            void* args[],
            cudaStream_t stream) {
        CUDA_DRIVER_CHECK(cuLaunchKernel(
            func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
        ));
    }
    """
    if torch.version.hip is not None:
        # Adjust the warp size to the GPU-supported wavefront size on AMD GPUs
        prop = torch.cuda.get_device_properties(torch.cuda.current_device())
        source_codes = source_codes.replace(
            "32*numWarps", str(prop.warp_size) + "*numWarps"
        )
    return source_codes


def cuda_kernel_header() -> str:
    source_codes = """
    #include <c10/cuda/CUDAGuard.h>
    #include <c10/cuda/CUDAStream.h>
    #include <ATen/cuda/EmptyTensor.h>
    """
    return source_codes
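A minimal sketch of how these two snippet generators fit together, assuming a caller that splices them into a generated C++ translation unit the way the comment at the top of the file suggests (the splicing itself is done elsewhere and is not shown in this diff).

# Hypothetical usage sketch -- not part of the diff above.
from torch._inductor.codegen.codegen_device_driver import (
    cuda_kernel_driver,
    cuda_kernel_header,
)

# The header contributes the c10/ATen CUDA includes; the driver block
# defines loadKernel/launchKernel around the CUDA driver API.
generated_cpp = cuda_kernel_header() + cuda_kernel_driver()
assert "cuLaunchKernel" in generated_cpp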
.venv/Lib/site-packages/torch/_inductor/codegen/common.py ADDED
@@ -0,0 +1,2167 @@
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import contextlib
|
| 3 |
+
import dataclasses
|
| 4 |
+
import functools
|
| 5 |
+
import itertools
|
| 6 |
+
import logging
|
| 7 |
+
import math
|
| 8 |
+
import operator
|
| 9 |
+
import re
|
| 10 |
+
from enum import auto, Enum
|
| 11 |
+
from itertools import chain
|
| 12 |
+
from typing import (
|
| 13 |
+
Any,
|
| 14 |
+
Callable,
|
| 15 |
+
ClassVar,
|
| 16 |
+
Dict,
|
| 17 |
+
List,
|
| 18 |
+
NamedTuple,
|
| 19 |
+
Optional,
|
| 20 |
+
Tuple,
|
| 21 |
+
Union,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
import sympy
|
| 25 |
+
from sympy.printing.printer import Printer
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import torch.fx
|
| 29 |
+
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
|
| 30 |
+
from torch.utils import _pytree as pytree
|
| 31 |
+
from torch.utils._ordered_set import OrderedSet
|
| 32 |
+
from torch.utils._sympy.numbers import int_oo
|
| 33 |
+
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
|
| 34 |
+
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
|
| 35 |
+
|
| 36 |
+
from .. import config, metrics
|
| 37 |
+
from ..utils import (
|
| 38 |
+
DeferredLineBase,
|
| 39 |
+
generate_assert,
|
| 40 |
+
IndentedBuffer,
|
| 41 |
+
sympy_dot,
|
| 42 |
+
sympy_subs,
|
| 43 |
+
unique,
|
| 44 |
+
)
|
| 45 |
+
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def data_type_logger(msg):
|
| 52 |
+
if schedule_log.isEnabledFor(logging.DEBUG):
|
| 53 |
+
schedule_log.debug("Data type propagation: %s", msg)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclasses.dataclass
|
| 57 |
+
class WorkspaceArg:
|
| 58 |
+
"""A temporary buffer used for a single kernel, then discarded.
|
| 59 |
+
|
| 60 |
+
Not registered as a traditional buffer since there are no users,
|
| 61 |
+
so it would be dead code eliminated.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
nbytes: sympy.Expr
|
| 65 |
+
zero_fill: bool
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@dataclasses.dataclass
|
| 69 |
+
class TensorArg:
|
| 70 |
+
name: str
|
| 71 |
+
buffer: str
|
| 72 |
+
dtype: torch.dtype
|
| 73 |
+
offset: sympy.Expr = sympy.Integer(0) # c++ only
|
| 74 |
+
alias_of: Optional[str] = None # halide only
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclasses.dataclass
|
| 78 |
+
class SizeArg:
|
| 79 |
+
name: str
|
| 80 |
+
expr: sympy.Expr
|
| 81 |
+
|
| 82 |
+
@property
|
| 83 |
+
def alias_of(self):
|
| 84 |
+
return None
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@dataclasses.dataclass
|
| 88 |
+
class DeviceCodegen:
|
| 89 |
+
scheduling: Any
|
| 90 |
+
wrapper_codegen: type
|
| 91 |
+
cpp_wrapper_codegen: type = type(None)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]
|
| 95 |
+
|
| 96 |
+
device_codegens: Dict[str, DeviceCodegen] = {}
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class DeviceOpOverrides:
|
| 100 |
+
def import_get_raw_stream_as(self, name):
|
| 101 |
+
raise NotImplementedError
|
| 102 |
+
|
| 103 |
+
def set_device(self, device_idx):
|
| 104 |
+
raise NotImplementedError
|
| 105 |
+
|
| 106 |
+
def synchronize(self):
|
| 107 |
+
raise NotImplementedError
|
| 108 |
+
|
| 109 |
+
def device_guard(self, device_idx):
|
| 110 |
+
raise NotImplementedError
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
|
| 117 |
+
# For any new backend looking to integrate with Inductor, customization of these two main
|
| 118 |
+
# parts are necessary to generate its specific code.
|
| 119 |
+
#
|
| 120 |
+
# Kernel code generation is determined by different Scheduling. Consequently, a new
|
| 121 |
+
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
|
| 122 |
+
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
|
| 123 |
+
#
|
| 124 |
+
# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
|
| 125 |
+
# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
|
| 126 |
+
# and override specific member functions to create backend-specific Python wrapper code.
|
| 127 |
+
#
|
| 128 |
+
# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
|
| 129 |
+
# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
|
| 130 |
+
# provide flexibility to the backend. A backend can choose to implement these classes from scratch,
|
| 131 |
+
# or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
|
| 132 |
+
# register_backend_for_device, to equip a new backend at runtime.
|
| 133 |
+
#
|
| 134 |
+
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
|
| 135 |
+
# This backend can be used as a reference:
|
| 136 |
+
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
|
| 137 |
+
def register_backend_for_device(
|
| 138 |
+
device: str,
|
| 139 |
+
device_scheduling: Any,
|
| 140 |
+
device_wrapper_codegen: type,
|
| 141 |
+
device_cpp_wrapper_codegen: type = type(None),
|
| 142 |
+
):
|
| 143 |
+
device_codegens[device] = DeviceCodegen(
|
| 144 |
+
device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class BackendFeature(Enum):
|
| 149 |
+
FOREACH = auto()
|
| 150 |
+
BUCKETIZE = auto()
|
| 151 |
+
INPLACE_BUFFERS = auto()
|
| 152 |
+
MASKED_SCATTER_WITH_INDEX = auto()
|
| 153 |
+
SCAN = auto()
|
| 154 |
+
SORT = auto()
|
| 155 |
+
TUPLE_REDUCTION = auto()
|
| 156 |
+
PREFER_STORE_LOOP_ORDER = auto()
|
| 157 |
+
TRITON_TEMPLATES = auto()
|
| 158 |
+
REDUCE_TO_SINGLE_ELEMENT = auto()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def get_backend_features(device: Union[torch.device, str]):
|
| 162 |
+
init_backend_registration()
|
| 163 |
+
if isinstance(device, torch.device):
|
| 164 |
+
device_type = device.type
|
| 165 |
+
else:
|
| 166 |
+
assert isinstance(device, str)
|
| 167 |
+
device_type = device
|
| 168 |
+
device = torch.device(device_type)
|
| 169 |
+
scheduling = get_scheduling_for_device(device_type)
|
| 170 |
+
return scheduling(None).get_backend_features(device)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def has_backend_feature(device, feature):
|
| 174 |
+
"""See also V.graph.has_feature"""
|
| 175 |
+
assert isinstance(feature, BackendFeature)
|
| 176 |
+
return feature in get_backend_features(device)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def get_scheduling_for_device(device: str):
|
| 180 |
+
return device_codegens[device].scheduling if device in device_codegens else None
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False):
|
| 184 |
+
if device in device_codegens:
|
| 185 |
+
wrapper_codegen_obj: DeviceCodegen = device_codegens[device]
|
| 186 |
+
return (
|
| 187 |
+
wrapper_codegen_obj.cpp_wrapper_codegen
|
| 188 |
+
if cpp_wrapper
|
| 189 |
+
else wrapper_codegen_obj.wrapper_codegen
|
| 190 |
+
)
|
| 191 |
+
else:
|
| 192 |
+
return None
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@functools.lru_cache(None)
|
| 196 |
+
def init_backend_registration():
|
| 197 |
+
from .cpp import CppScheduling
|
| 198 |
+
from .cpp_wrapper_cpu import CppWrapperCpu
|
| 199 |
+
from .cpp_wrapper_cuda import CppWrapperCuda
|
| 200 |
+
from .cuda_combined_scheduling import CUDACombinedScheduling
|
| 201 |
+
from .halide import HalideScheduling
|
| 202 |
+
from .triton import TritonScheduling
|
| 203 |
+
from .wrapper import WrapperCodeGen
|
| 204 |
+
|
| 205 |
+
if get_scheduling_for_device("cpu") is None:
|
| 206 |
+
cpu_backends = {"cpp": CppScheduling, "halide": HalideScheduling}
|
| 207 |
+
register_backend_for_device(
|
| 208 |
+
"cpu",
|
| 209 |
+
lambda *args, **kwargs: cpu_backends[config.cpu_backend](*args, **kwargs),
|
| 210 |
+
WrapperCodeGen,
|
| 211 |
+
CppWrapperCpu,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
if get_scheduling_for_device("cuda") is None:
|
| 215 |
+
# CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
|
| 216 |
+
cuda_backends = {"triton": CUDACombinedScheduling, "halide": HalideScheduling}
|
| 217 |
+
register_backend_for_device(
|
| 218 |
+
"cuda",
|
| 219 |
+
lambda *args, **kwargs: cuda_backends[config.cuda_backend](*args, **kwargs),
|
| 220 |
+
WrapperCodeGen,
|
| 221 |
+
CppWrapperCuda,
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
if get_scheduling_for_device("xpu") is None:
|
| 225 |
+
register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen)
|
| 226 |
+
|
| 227 |
+
private_backend = torch._C._get_privateuse1_backend_name()
|
| 228 |
+
if (
|
| 229 |
+
private_backend != "privateuseone"
|
| 230 |
+
and get_scheduling_for_device(private_backend) is None
|
| 231 |
+
):
|
| 232 |
+
from torch.utils.backend_registration import _get_custom_mod_func
|
| 233 |
+
|
| 234 |
+
try:
|
| 235 |
+
device_scheduling = _get_custom_mod_func("Scheduling")
|
| 236 |
+
wrapper_codegen = _get_custom_mod_func("WrapperCodeGen")
|
| 237 |
+
cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodeGen")
|
| 238 |
+
if device_scheduling and wrapper_codegen and cpp_wrapper_codegen:
|
| 239 |
+
register_backend_for_device(
|
| 240 |
+
private_backend,
|
| 241 |
+
device_scheduling,
|
| 242 |
+
wrapper_codegen,
|
| 243 |
+
cpp_wrapper_codegen,
|
| 244 |
+
)
|
| 245 |
+
except RuntimeError:
|
| 246 |
+
pass
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
|
| 250 |
+
from ..ir import FlexibleLayout
|
| 251 |
+
|
| 252 |
+
# added contiguous index prevents reordering
|
| 253 |
+
return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
|
| 257 |
+
device_op_overrides_dict[device] = device_op_overrides
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def get_device_op_overrides(device: str):
|
| 261 |
+
assert isinstance(device, str)
|
| 262 |
+
|
| 263 |
+
if not device_op_overrides_dict.keys():
|
| 264 |
+
from .cuda import device_op_overrides # noqa: F401
|
| 265 |
+
from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401
|
| 266 |
+
|
| 267 |
+
if device in device_op_overrides_dict.keys():
|
| 268 |
+
return device_op_overrides_dict[device]
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
@functools.lru_cache(None)
|
| 272 |
+
def boolean_ops():
|
| 273 |
+
return (
|
| 274 |
+
"isinf",
|
| 275 |
+
"isnan",
|
| 276 |
+
"logical_not",
|
| 277 |
+
"signbit",
|
| 278 |
+
"le",
|
| 279 |
+
"lt",
|
| 280 |
+
"ge",
|
| 281 |
+
"gt",
|
| 282 |
+
"eq",
|
| 283 |
+
"ne",
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
DTYPE_TO_COMPUTATION_DTYPE = {
|
| 288 |
+
torch.bfloat16: torch.float,
|
| 289 |
+
torch.float16: torch.float,
|
| 290 |
+
**{
|
| 291 |
+
dtype: dtype
|
| 292 |
+
for dtype in [
|
| 293 |
+
torch.bool,
|
| 294 |
+
torch.float32,
|
| 295 |
+
torch.float64,
|
| 296 |
+
torch.int8,
|
| 297 |
+
torch.int16,
|
| 298 |
+
torch.int32,
|
| 299 |
+
torch.int64,
|
| 300 |
+
torch.uint8,
|
| 301 |
+
torch.uint16,
|
| 302 |
+
torch.uint32,
|
| 303 |
+
torch.uint64,
|
| 304 |
+
]
|
| 305 |
+
},
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def deduce_output_dtype_by_name(
    op_name: str,
    *args,
    **kwargs,
) -> Optional[torch.dtype]:
    """
    Given an op name and its args/kwargs, deduce the output dtype
    """
    if op_name in boolean_ops():
        return torch.bool
    elif op_name in (
        "to_dtype",
        "index_expr",
    ):
        return kwargs["dtype"] if "dtype" in kwargs else args[-1]
    elif op_name in (
        "rand",
        "randn",
    ):
        return torch.float
    elif op_name in (
        "get_index",
        "randint64",
        "load_seed",
    ):
        return torch.int64
    elif op_name == "reduction":
        return kwargs["dtype"] if "dtype" in kwargs else args[1]
    elif op_name == "constant":
        dtype = kwargs["dtype"] if "dtype" in kwargs else args[-1]
        return DTYPE_TO_COMPUTATION_DTYPE[dtype]  # type: ignore[index]
    elif op_name in (
        "load",
        "store",
        "store_reduction",
    ):
        buf_name = args[1]
        return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]
    elif op_name == "to_dtype_bitcast":
        return kwargs["dtype"] if "dtype" in kwargs else args[-2]
    return None

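# Illustrative examples (not part of the original file):
#   deduce_output_dtype_by_name("eq", a, b)                    -> torch.bool
#   deduce_output_dtype_by_name("rand", seed, offset)          -> torch.float
#   deduce_output_dtype_by_name("to_dtype", x, torch.float16)  -> torch.float16
# Unrecognized ops fall through to None, and callers then fall back to
# input-based promotion (see DataTypePropagation.deduce_node_dtype_by_inputs).
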
class DataTypePropagation:
    def __init__(self, body) -> None:
        self.body = body
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        if node.op == "placeholder":
            return None

        if node.target == "output" and len(node.args) != 1:
            # we can only infer the dtype of the output node when it has exactly 1 arg
            return None

        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        if (
            output_dtype := deduce_output_dtype_by_name(
                node.target,
                *node.args,
                **node.kwargs,
            )
        ) is not None:
            return output_dtype

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        assert graph.nodes
        graph_dtype = None
        # For masked_subblock, we use output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        from ..loop_body import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        DataTypePropagation.propagate_loopbody(node._body)

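# Illustrative note (not part of the original file): input-based deduction
# folds standard type promotion pairwise over all propagated input dtypes, e.g.
#   functools.reduce(torch.promote_types, [torch.int32, torch.float16])
#   -> torch.float16
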
# This printer contains rules that are supposed to be generic for both C/C++ and
# Python
class ExprPrinter(Printer):
    @staticmethod
    def paren(string):
        def all_in_parens(string):
            if string[0] != "(" or len(string) < 2:
                return False
            count = 1
            for i, char in enumerate(string[1:]):
                if char == "(":
                    count += 1
                elif char == ")":
                    count -= 1
                if count == 0 and i != len(string) - 2:
                    return False
            assert count == 0
            return True

        if (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
            or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
            or string == ""
        ):
            return string
        # don't put extra parens for strings that are already wrapped in parens
        if all_in_parens(string):
            return string
        return f"({string})"

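    # Illustrative examples (not part of the original file):
    #   ExprPrinter.paren("x0")      -> "x0"       (atoms pass through)
    #   ExprPrinter.paren("a + b")   -> "(a + b)"
    #   ExprPrinter.paren("(a + b)") -> "(a + b)"  (already fully parenthesized)
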
    def _print_Relational(self, expr):
        return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mul(self, expr):
        return "*".join(map(self.paren, map(self._print, expr.args)))

    def _print_Add(self, expr):
        return " + ".join(map(self.paren, map(self._print, expr.args)))

    # NB: this is OK to put here, because Mod is only defined for positive
    # numbers, and so across C/Python its behavior is consistent
    def _print_Mod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_FloatTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _print_CleanDiv(self, expr):
        return self._print_FloorDiv(expr)

    def _print_Identity(self, expr):
        return self._print(expr.args[0])

    def _print_GreaterThan(self, expr):
        # GreaterThan:         >=
        # StrictlyGreaterThan: >
        # Go figure...
        return " >= ".join(map(self.paren, map(self._print, expr.args)))

    # NB: The C implementation is injected into codegen at
    # torch/_inductor/codegen/wrapper.py
    def _print_align(self, expr):
        assert len(expr.args) == 1
        return f"align({self._print(expr.args[0])})"

    # This must be implemented because sympy will collect x * x into Pow(x, 2),
    # without any explicit intervention.  We print it just like x * x, notably,
    # we never generate sympy.Pow with floats.
    #
    # NB: this is pow by natural: you should never have used builtin sympy.Pow
    # for FloatPow, and a symbolic exponent should be PowByNatural.  This
    # means exp is guaranteed to be an integer.
    def _print_Pow(self, expr):
        base, exp = expr.args
        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        assert exp >= 0
        if exp > 0:
            return "*".join([self.paren(base)] * exp)
        else:  # exp == 0
            return "1"

    # Explicit NotImplemented functions are to prevent default sympy printing
    # behavior, which will just barf out ToFloat(...) to your IR.  The error
    # message is better here because it tells you which printer class it needs
    # to go in.

    def _print_ToFloat(self, expr):
        raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")

    def _print_Infinity(self, expr):
        raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")

    def _print_NegativeInfinity(self, expr):
        raise NotImplementedError(
            f"_print_NegativeInfinity not implemented for {type(self)}"
        )

    def _print_FloorDiv(self, expr):
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_PythonMod(self, expr):
        raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")

    def _print_IntTrueDiv(self, expr):
        raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")

    def _print_PowByNatural(self, expr):
        raise NotImplementedError(
            f"_print_PowByNatural not implemented for {type(self)}"
        )

    def _print_FloatPow(self, expr):
        raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")

    def _print_TruncToInt(self, expr):
        raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")

    def _print_RoundToInt(self, expr):
        raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")

    def _print_RoundDecimal(self, expr):
        raise NotImplementedError(
            f"_print_RoundDecimal not implemented for {type(self)}"
        )

    # NB: Some float operations are INTENTIONALLY not implemented for
    # printers.  You can implement them as a quick unblock, but it is better
    # to ask yourself why we haven't done this computation in the Tensor
    # universe instead

    def _print_TruncToFloat(self, expr):
        raise NotImplementedError(
            f"_print_TruncToFloat not implemented for {type(self)}"
        )

    def doprint(self, expr, *, simplify: bool = True):
        # TODO: why are people passing strings to the printer here :think:
        if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
            expr = V.graph.sizevars.simplify(expr)
        return super().doprint(expr)

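# Illustrative example (not part of the original file): with the exponent
# collected by sympy, _print_Pow(Pow(x, 3)) prints as "x*x*x" and
# _print_Pow(Pow(x, 0)) prints as "1"; no pow() call is ever emitted for
# natural exponents.
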
class PythonPrinter(ExprPrinter):
    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"float({self._print(expr.args[0])})"

    def _print_ModularIndexing(self, expr):
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        mod = self.paren(self.doprint(mod))
        if div != "1":
            x = f"({x} // {div})"
        return f"{x} % {mod}"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_PythonMod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_FloorDiv(self, expr):
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"({x} // {div})"

    # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
    # does a special algorithm
    def _print_IntTrueDiv(self, expr):
        lhs, rhs = expr.args
        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"

    def _helper_sqrt(self, expr):
        return f"math.sqrt({self._print(expr)})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        return self._helper_sqrt(expr.args[0])

    def _print_FloatPow(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    # TODO: Not sure this works with Triton, even when base/exp are integral
    def _print_PowByNatural(self, expr):
        base, exp = expr.args
        return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_FloorToInt(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_TruncToInt(self, expr):
        assert len(expr.args) == 1
        # This also could have been int(), they'll do the same thing for float
        return f"math.trunc({self._print(expr.args[0])})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_CeilToInt(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"abs({self._print(expr.args[0])})"

    # NB: It's expected that we've made explicit any promotion in the sympy
    # expression, so it doesn't matter that Python max/min doesn't perform
    # promotion
    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        return f"max({', '.join(map(self._print, expr.args))})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        return f"min({', '.join(map(self._print, expr.args))})"

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"math.cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"math.cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"math.acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"math.sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"math.sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"math.asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"math.tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"math.tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"math.atan({self._print(expr.args[0])})"

    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        return f"round({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"

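# Illustrative examples (not part of the original file):
#   ModularIndexing(x0, 2, 8) prints as "(x0 // 2) % 8"
#   FloorDiv(x0, 4)           prints as "(x0 // 4)"
# Both rely on Python floor-division/modulus semantics, hence the WARNINGs
# above about reusing these rules for Triton's C-style operators.
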
class OpOverrides:
    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv(ops.constant(1, torch.int32), x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def erfc(x):
        return ops.sub(ops.constant(1, torch.float32), ops.erf(x))

    @staticmethod
    def erfcx(x):
        return ops.mul(ops.exp(ops.square(x)), ops.erfc(x))

    @staticmethod
    def expm1(x):
        return ops.sub(ops.exp(x), ops.constant(1, torch.float32))

    @staticmethod
    def log10(x):
        return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32))

    @staticmethod
    def log2(x):
        return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32))

    @staticmethod
    def exp2(x):
        return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32)))

    @staticmethod
    def log1p(x):
        return ops.log(ops.add(x, ops.constant(1, torch.int32)))

    @staticmethod
    def sigmoid(x):
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))

    @staticmethod
    def libdevice_sigmoid(x):
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))

    @staticmethod
    def relu(x):
        return ops.maximum(x, ops.constant(0, torch.int32))

    @staticmethod
    def libdevice_abs(x):
        return ops.abs(x)

    @staticmethod
    def libdevice_sqrt(x):
        return ops.sqrt(x)

    @staticmethod
    def libdevice_cos(x):
        return ops.cos(x)

    @staticmethod
    def libdevice_sin(x):
        return ops.sin(x)

    @staticmethod
    def libdevice_log(x):
        return ops.log(x)

    @staticmethod
    def libdevice_exp(x):
        return ops.exp(x)

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

    @staticmethod
    def remainder(a, b):
        r = ops.mod(a, b)
        cond = ops.and_(
            ops.ne(r, ops.constant(0, torch.int32)),
            ops.ne(ops.signbit(r), ops.signbit(b)),
        )
        return ops.where(cond, ops.add(r, b), r)

    @staticmethod
    def trunc_to_int(a, dtype):
        return ops.to_dtype(ops.trunc(a), dtype)

    @staticmethod
    def floor_to_int(a, dtype):
        return ops.to_dtype(ops.floor(a), dtype)

    @staticmethod
    def ceil_to_int(a, dtype):
        return ops.to_dtype(ops.ceil(a), dtype)

    @staticmethod
    def round_to_int(a, dtype):
        return ops.to_dtype(ops.round(a), dtype)

    @staticmethod
    def int_truediv(a, b):
        # TODO: this is wrong
        # TODO: an easy bandaid is to generate runtime asserts that it's
        # <= 2**53, which is when this equation is correct
        return ops.truediv(a, b)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        assert target in {"triton", "cpp", "cppvec"}, target

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if impl is None:
                continue
            setattr(cls, funcname, staticmethod(impl))

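# Illustrative note (not part of the original file): remainder() composes the
# primitive ops into Python-style modulus, whose result takes the sign of the
# divisor.  E.g. for a = -7, b = 3, a truncated (C-style) ops.mod gives -1;
# since -1 != 0 and signbit(-1) != signbit(3), the correction branch yields
# -1 + 3 == 2, matching Python's -7 % 3.
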
@dataclasses.dataclass
class OverridesData:
    name: str
    cpp: Callable[..., str]
    # None when not impl in libdevice/triton
    triton: Optional[Callable[..., str]] = None
    # None when not impl in aten/.../vec
    cppvec: Optional[Callable[..., str]] = None
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )


# NB: if you add a new special function, don't forget to update
# torch._inductor.ops_handler too
pointwise_overrides_data: Dict[str, OverridesData] = dict(
    airy_ai=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"airy_ai_forward({x})",
        name="special_airy_ai",
    ),
    bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j0_forward({x})",
        triton=lambda x: f"libdevice.j0({x})",
        name="special_bessel_j0",
    ),
    bessel_j1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_j1_forward({x})",
        triton=lambda x: f"libdevice.j1({x})",
        name="special_bessel_j1",
    ),
    bessel_y0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y0_forward({x})",
        triton=lambda x: f"libdevice.y0({x})",
        name="special_bessel_y0",
    ),
    bessel_y1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"bessel_y1_forward({x})",
        triton=lambda x: f"libdevice.y1({x})",
        name="special_bessel_y1",
    ),
    digamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_digamma({x})",
        cppvec=lambda x: f"{x}.digamma()",
        name="digamma",
    ),
    # no cpp nor triton implementation for entr, it is defined as decomposition
    # erf, erfc
    erfcx=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_erfcx({x})",
        triton=lambda x: f"libdevice.erfcx({x})",
        name="special_erfcx",
    ),
    fma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
        cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
        triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
        name="fma",
    ),
    # erfinv, exp2, expit, gammaln
    igamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="igamma",
    ),
    igammac=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="igammac",
    ),
    gammainc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igamma({x}, {y})",
        name="special_gammainc",
    ),
    gammaincc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_igammac({x}, {y})",
        name="special_gammaincc",
    ),
    i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        cppvec=lambda x: f"{x}.i0()",
        name="i0",
    ),
    i0e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i0e({x})",
        cppvec=lambda x: f"{x}.i0e()",
        name="special_i0e",
    ),
    i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_i1",
    ),
    i1e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_i1e({x})",
        name="special_i1e",
    ),
    log_ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_log_ndtr({x})",
        name="special_log_ndtr",
    ),
    # logit
    modified_bessel_i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i0_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
        name="special_modified_bessel_i0",
    ),
    modified_bessel_i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_i1_forward({x})",
        triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
        name="special_modified_bessel_i1",
    ),
    modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k0_forward({x})",
        name="special_modified_bessel_k0",
    ),
    modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"modified_bessel_k1_forward({x})",
        name="special_modified_bessel_k1",
    ),
    # multigamma
    ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtr({x})",
        name="special_ndtr",
    ),
    ndtri=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"calc_ndtri({x})",
        name="special_ndtri",
    ),
    polygamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"calc_polygamma({y}, {x})",
        name="polygamma",
    ),
    # psi - alias to digamma
    # round
    scaled_modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
        name="special_scaled_modified_bessel_k0",
    ),
    scaled_modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
        name="special_scaled_modified_bessel_k1",
    ),
    # sinc
    spherical_bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x: f"spherical_bessel_j0_forward({x})",
        name="special_spherical_bessel_j0",
    ),
    zeta=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"zeta({x}, {y})",
        name="special_zeta",
    ),
    chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
        name="special_chebyshev_polynomial_t",
    ),
    chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
        name="special_chebyshev_polynomial_u",
    ),
    chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
        name="special_chebyshev_polynomial_v",
    ),
    chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
        name="special_chebyshev_polynomial_w",
    ),
    legendre_polynomial_p=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
        name="special_legendre_polynomial_p",
    ),
    shifted_chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_t",
    ),
    shifted_chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_u",
    ),
    shifted_chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_v",
    ),
    shifted_chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_w",
    ),
    hermite_polynomial_h=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
        name="special_hermite_polynomial_h",
    ),
    hermite_polynomial_he=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
        name="special_hermite_polynomial_he",
    ),
    laguerre_polynomial_l=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
        name="special_laguerre_polynomial_l",
    ),
)

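# Illustrative note (not part of the original file): each entry here becomes a
# staticmethod on a backend's OpOverrides subclass via
# _initialize_pointwise_overrides.  E.g., after a hypothetical
# SomeCppOverrides._initialize_pointwise_overrides("cpp"), the class gains
#   digamma(x) -> "calc_digamma(x)"
# while entries whose field for that target is None (no impl) are skipped.
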
# Use mypy to check protocol implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    return h


class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        if all(
            self.name not in x
            for x in (
                V.graph.removed_buffers,
                V.kernel.removed_buffers,
                V.graph.inplaced_to_remove,
                V.kernel.inplaced_to_remove,
            )
        ):
            return self.line
        return None

    def _new_line(self, line):
        return DeferredLine(self.name, line)


class BracesBuffer(IndentedBuffer):
    def indent(self, offset=1):
        @contextlib.contextmanager
        def ctx():
            for _ in range(offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(-offset):
                self._indent -= 1
                self.writeline("}")
            yield
            for _ in range(-offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(offset):
                self._indent -= 1
                self.writeline("}")

        return ctx()

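# Illustrative example (not part of the original file): when emitting C++,
# `with buf.indent():` brackets the indented region in literal braces:
#     {
#         ...body lines...
#     }
# A negative offset does the reverse, temporarily closing braces instead.
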
class InplacedBuffer(NamedTuple):
    inner_name: str
    other_names: List[str]


class KernelArgs:
    @staticmethod
    def _lookup(prefix, odict, name):
        assert isinstance(name, (str, sympy.Symbol))
        if name not in odict:
            odict[name] = f"{prefix}{len(odict)}"
        return odict[name]

    def __init__(self, sizevars=None):
        self.input_buffers = {}
        self.output_buffers = {}
        self.inplace_buffers = {}
        self.sizevars = sizevars or {}
        self.workspace_arg = None

    def __repr__(self):
        return "KernelArgs({})".format(
            ", ".join(
                map(
                    repr,
                    [
                        self.input_buffers,
                        self.output_buffers,
                        self.inplace_buffers,
                        self.sizevars,
                    ],
                )
            )
        )

    def _buffer_is_marked_removed(self, name):
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        if name.startswith("seed"):
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name):
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            buf = self.inplace_buffers[input_name]
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            buf = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
        if self.workspace_arg is None:
            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
            return "ws_ptr", 0

        offset = self.workspace_arg.nbytes
        zero_fill = zero_fill or self.workspace_arg.zero_fill
        self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill)
        return "ws_ptr", offset

    def seed_offset(self, name, value):
        if value in self.sizevars:
            return self.sizevars[value]
        if name in self.sizevars.values():
            name = (
                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
            )
        self.sizevars[value] = name
        return name

    def size(self, name):
        if str(name) == "seed":
            self.sizevars["seed"] = "seed"
            return "seed"
        return self._lookup("ks", self.sizevars, name)

    def call_names(self):
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def wrap_ptr_arg(self, buf, dtype):
        return buf

    def wrap_size_arg(self, size):
        return str(size)

    def cpp_argdefs(self):
        from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert self.workspace_arg is None, "Workspace not supported on CPU "
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        arg_defs: List[str] = []
        call_args: List[str] = []
        arg_types: List[torch.dtype] = []
        precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            arg_defs.append(inplaced.inner_name)
            call_args.append(inplaced.other_names[-1])
            arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
            precompile_args.append(
                TensorArg(
                    name=inplaced.inner_name,
                    buffer=inplaced.other_names[-1],
                    dtype=V.graph.get_dtype(inplaced.other_names[-1]),
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(V.graph.get_dtype(outer))
            precompile_args.append(
                TensorArg(
                    name=inner,
                    buffer=outer,
                    dtype=V.graph.get_dtype(outer),
                )
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            arg_types.append(type(outer))  # type: ignore[arg-type]
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        if self.workspace_arg is not None:
            arg_defs.append("ws_ptr")
            call_args.append("workspace")
            precompile_args.append(self.workspace_arg)
        return arg_defs, call_args, precompile_args, arg_types

    def aliases(self):
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield self.output_buffers[other], inplaced.inner_name

    def is_removed(self, name):
        def _is_removed(name, buffers):
            return name not in buffers or self._buffer_is_marked_removed(buffers[name])

        return _is_removed(name, self.output_buffers) and _is_removed(
            name, self.inplace_buffers
        )

    # Includes inplace buffers, excludes removed buffers.  Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data?  Modeled off of python_argdefs.
    def live_output_buffers(self):
        live_outs = OrderedSet()  # type: ignore[var-annotated]
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            live_outs.add(outer)
        return live_outs

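# Illustrative example (not part of the original file), with hypothetical
# buffer names: after args.make_inplace("buf0", "buf3"), both names share one
# InplacedBuffer with inner_name "in_out_ptr0", so args.input("buf0") and
# args.output("buf3") both resolve to "in_out_ptr0" and argdefs pass the
# buffer once, as a single in/out pointer.
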
class CSEVariable:
    """A CSEVariable is just a name for an expression, but it is useful to be able
    to annotate them on a backend-dependent basis.
    To do so, the backends can simply overload `Kernel.create_cse_var`
    The "CSEVariable.update_on_args" method gives you a hook for annotations
    See example of TritonCSEVariable in triton.py
    """

    def __init__(self, name, bounds: ValueRanges[Any]):
        assert isinstance(bounds, ValueRanges)
        self.name = name
        self.bounds = bounds
        self.use_count = 1  # track how many times this expression is used

    def __str__(self):
        return self.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other) -> bool:
        return type(other) == type(self) and other.name == self.name

    def update_on_args(self, name, args, kwargs):
        pass

    def __repr__(self):
        return f"{self.__class__.__name__}({self.name!r})"


class CppWrapperKernelArgs(KernelArgs):
    def wrap_ptr_arg(self, buf, dtype):
        from .cpp_utils import DTYPE_TO_CPP

        if config.abi_compatible:
            # In the abi_compatible model, we just return the buf here.
            # We will form correct call args later in wrapper.generate_kernel_all.
            return buf
        else:
            return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"

    def wrap_size_arg(self, size):
        return f"{size}"


class CSE:
    """Common subexpression elimination"""

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
    ):
        self.prefix = prefix
        self.suffix = suffix
        self.cache = {}
        self.name_prefix = name_prefix
        self.store_cache = store_cache or {}
        self.reduction_cache = reduction_cache or {}
        self.iter_buffer_ids = iter_buffers or itertools.count()
        self.invalidated_stores = OrderedSet()  # type: ignore[var-annotated]
        self.varname_map = varname_map or {}

    def invalidate(self, keep_vars: OrderedSet[str]):
        for name, tmp in list(self.store_cache.items()):
            if tmp not in keep_vars:
                del self.store_cache[name]
                self.invalidated_stores.add(name)
        self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}

    def clone(self):
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
        return CSE(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write=True,
        assignment=True,
    ) -> CSEVariable:
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            expr.use_count += 1
            return expr
        cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
        var = self.cache.get(cache_key, None)
        if not var:
            var = self.newvar(bounds)
            self.cache[cache_key] = var
            if write:
                if V.kernel.current_node:
                    V.kernel.current_node.codegen_originating_info(
                        buffer, only_once=True
                    )
                if isinstance(expr, IndentedBuffer):
                    if assignment:
                        buffer.writeline(f"{self.prefix}{var} =")
                    buffer.splice(expr)
                    buffer.writeline(self.suffix)
                else:
                    if assignment:
                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
                    else:
                        line = f"{expr}{self.suffix}"
                    buffer.writeline(line)
        else:
            var.bounds = var.bounds.tighten(bounds)
            var.use_count += 1

        return var

    def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        var = V.kernel.create_cse_var(var_name, bounds)
        self.varname_map[var_name] = var
        return var

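# Illustrative example (not part of the original file), with hypothetical
# names: the first cse.generate(buf, "tmp1 + tmp2") allocates a fresh var
# (say tmp3), writes "tmp3 = tmp1 + tmp2" to buf, and returns tmp3; a second
# call with the identical expression string hits self.cache, bumps use_count,
# and returns the same tmp3 without emitting another line.
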
class CodeGen:
    def __init__(self) -> None:
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        self.exit_stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)


class ScopedDict:
    def __init__(self, original_dict):
        self.original_dict = original_dict
        self.new_items = {}

    def __getitem__(self, key):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict[key]

    def __setitem__(self, key, value):
        self.new_items[key] = value

    def __contains__(self, key):
        return key in self.new_items or key in self.original_dict

    def get(self, key, default=None):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict.get(key, default)

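# Illustrative example (not part of the original file): ScopedDict layers new
# writes over a base mapping without mutating it, e.g.
#   base = {"a": 1}
#   sd = ScopedDict(base)
#   sd["b"] = 2
#   sd["a"], sd["b"], base   -> 1, 2, {"a": 1}
# which is exactly how Kernel.swap_buffers (below) scopes the CSE caches.
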
class Kernel(CodeGen):
    newvar_prefix = ""
    suffix = ""
    overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
    # TODO: these look dead, but with all the getattr it's hard to tell...
    load_format: None = None
    store_format: None = None

    def __init__(self, args=None, increase_kernel_count=True):
        super().__init__()
        if increase_kernel_count:
            metrics.generated_kernel_count += 1
        self.args = args or KernelArgs()
        self.loads = IndentedBuffer()
        self.compute = IndentedBuffer()
        self.stores = IndentedBuffer()

        self.num_load = 0
        self.num_reduction = 0

        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers = OrderedSet()  # type: ignore[var-annotated]
        self.store_buffer_names = OrderedSet()  # type: ignore[var-annotated]
        self._load_mask = None
        self._load_other = None
        # OrderedSet in set_current_node
        self.current_node = None
        self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None

        self.removed_buffers = OrderedSet()  # type: ignore[var-annotated]
        self.inplaced_to_remove = OrderedSet()  # type: ignore[var-annotated]

        # key: the buffer to write
        # value: the buffer to read and whose memory can be reused for
        # the buffer specified by key
        self.inplace_update_buffers = {}
        # Set minimum number of elements processed per thread.
        self.min_elem_per_thread = 1
        self.kernel_name = None

    @contextlib.contextmanager
    def set_current_node(self, node):
        prior = self.current_node
        self.current_node = node
        self.node_to_bounds = node._body.bounds().get_bounds()
        try:
            yield
        finally:
            self.current_node = prior

    @contextlib.contextmanager
    def swap_buffers(self, lb, cb=None, sb=None):
        def scope_cse(cse):
            new_cse = cse.clone()
            new_cse.cache = ScopedDict(cse.cache)
            new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
            new_cse.store_cache = ScopedDict(cse.store_cache)
            return new_cse

        if cb is None:
            cb = lb
        loads = self.loads
        compute = self.compute
        stores = self.stores
        cse = self.cse
        self.loads = lb
        self.compute = cb
        self.stores = sb
        self.cse = scope_cse(cse)
        try:
            yield
        finally:
            self.loads = loads
            self.compute = compute
            self.stores = stores
            self.cse = cse

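    # Illustrative note (not part of the original file): inside
    # `with kernel.swap_buffers(lb, cb, sb): ...` emitted lines go to the
    # supplied buffers instead of self.loads/compute/stores, and the CSE
    # caches are wrapped in ScopedDict, so entries created in the scope vanish
    # on exit while pre-existing entries stay visible throughout.
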
def load(self, name: str, index: sympy.Expr) -> CSEVariable:
|
| 1677 |
+
raise NotImplementedError
|
| 1678 |
+
|
| 1679 |
+
def indirect_load(self, name: str, index: sympy.Expr):
|
| 1680 |
+
"""A load the depends on an index we have read"""
|
| 1681 |
+
prior = self.loads
|
| 1682 |
+
try:
|
| 1683 |
+
# put the load in the compute section as it might have deps
|
| 1684 |
+
self.loads = self.compute
|
| 1685 |
+
return self.load(name, index)
|
| 1686 |
+
finally:
|
| 1687 |
+
self.loads = prior
|
| 1688 |
+
|
| 1689 |
+
def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
|
| 1690 |
+
raise NotImplementedError
|
| 1691 |
+
|
| 1692 |
+
def store(
|
| 1693 |
+
self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
|
| 1694 |
+
) -> None:
|
| 1695 |
+
raise NotImplementedError
|
| 1696 |
+
|
| 1697 |
+
def reduction(
|
| 1698 |
+
self,
|
| 1699 |
+
dtype: torch.dtype,
|
| 1700 |
+
src_dtype: torch.dtype,
|
| 1701 |
+
reduction_type: ReductionType,
|
| 1702 |
+
value: Union[CSEVariable, Tuple[CSEVariable, ...]],
|
| 1703 |
+
) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
|
| 1704 |
+
raise NotImplementedError
|
| 1705 |
+
|
| 1706 |
+
def scan(
|
| 1707 |
+
self,
|
| 1708 |
+
dtypes: Tuple[torch.dtype, ...],
|
| 1709 |
+
combine_fn: Callable[
|
| 1710 |
+
[Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
|
| 1711 |
+
],
|
| 1712 |
+
values: Tuple[CSEVariable, ...],
|
| 1713 |
+
) -> Tuple[CSEVariable, ...]:
|
| 1714 |
+
raise NotImplementedError
|
| 1715 |
+
|
| 1716 |
+
def sort(
|
| 1717 |
+
self,
|
| 1718 |
+
dtypes: Tuple[torch.dtype, ...],
|
| 1719 |
+
values: Tuple[CSEVariable, ...],
|
| 1720 |
+
stable: bool,
|
| 1721 |
+
descending: bool,
|
| 1722 |
+
) -> Tuple[CSEVariable, ...]:
|
| 1723 |
+
raise NotImplementedError
|
| 1724 |
+
|
| 1725 |
+
def var_ranges(self):
|
| 1726 |
+
raise NotImplementedError
|
| 1727 |
+
|
| 1728 |
+
def bucketize(
|
| 1729 |
+
self,
|
| 1730 |
+
values: CSEVariable,
|
| 1731 |
+
offsets_name: str,
|
| 1732 |
+
offsets_size: sympy.Expr,
|
| 1733 |
+
indexing_dtype: torch.dtype,
|
| 1734 |
+
right: bool,
|
| 1735 |
+
) -> CSEVariable:
|
| 1736 |
+
"""
|
| 1737 |
+
See [Note: Inductor bucketize op]
|
| 1738 |
+
"""
|
| 1739 |
+
raise NotImplementedError
|
| 1740 |
+
|
| 1741 |
+
@property
|
| 1742 |
+
def assert_function(self) -> str:
|
| 1743 |
+
raise NotImplementedError
|
| 1744 |
+
|
    def indirect_assert(
        self,
        var: Union[CSEVariable, str],
        lower: Optional[str],
        upper: Optional[str],
        mask: Optional[Union[CSEVariable, str]] = None,
    ) -> str:
        if isinstance(var, CSEVariable):
            var = str(var)
        assert isinstance(var, str)
        assert lower is None or isinstance(lower, str)
        assert upper is None or isinstance(upper, str)
        if lower and upper:
            # The conditions need to be in parens because of Python's operator precedence.
            # It'd be less error-prone to use and/or/not, which is supported by triton
            cond = f"({lower} <= {var}) & ({var} < {upper})"
            cond_print = f"{lower} <= {var} < {upper}"
        elif lower:
            cond = f"{lower} <= {var}"
            cond_print = cond
        else:
            assert upper
            cond = f"{var} < {upper}"
            cond_print = cond

        if mask:
            cond = f"({cond}) | ~({mask})"

        return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'

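    # Example of the emitted guard (assuming a Triton backend whose
    # assert_function is "tl.device_assert"): indirect_assert("tmp3", "0", "ks0")
    # would return
    #   tl.device_assert((0 <= tmp3) & (tmp3 < ks0), "index out of bounds: 0 <= tmp3 < ks0")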
    def check_bounds(
        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
    ):
        raise NotImplementedError

    def index_to_str(self, index: sympy.Expr) -> str:
        raise NotImplementedError

    def __enter__(self):
        # TODO: hoist this to top level
        class CSEProxy:
            self.name = "CSEProxy"
            vr_analysis = ValueRangeAnalysis()

            @staticmethod
            def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                def inner(*args, **kwargs):
                    bounds = CSEProxy._bound_variable(name, *args, **kwargs)

                    value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]

                    def do_cse(v):
                        csevar = V.kernel.cse.generate(
                            V.kernel.compute, v, bounds=bounds
                        )
                        csevar.update_on_args(name, args, kwargs)
                        return csevar

                    return pytree.tree_map(do_cse, value)

                return inner

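            # Every ops.* call routed through this proxy is (1) assigned value-range
            # bounds via _bound_variable, (2) delegated to the wrapped parent
            # handler, and (3) de-duplicated through the kernel's CSE cache, so
            # repeated identical expressions are emitted only once.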
            @staticmethod
            def _bound_variable(name, *args, **kwargs):
                """
                If the variable comes from an FX node, we forward the bound we have already computed.
                Otherwise, if the variable is created while codegen'ing another op, we try to compute its bounds.
                """
                from ..select_algorithm import TritonTemplateKernel

                if isinstance(V.kernel, TritonTemplateKernel):
                    return ValueRanges.unknown()

                fx_node = V.interpreter.current_node
                if fx_node.target == name and self.node_to_bounds is not None:
                    assert isinstance(self.node_to_bounds, dict)
                    return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
                elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
                    # These create lots of inner strings. We would need to compute the bounds at the ops
                    # We will also likely not get much from computing VRs on these nodes
                    if any(
                        s in fx_node.target
                        for s in ("set_indirect", "reduction", "scan")
                    ):
                        return ValueRanges.unknown()

                    # We assume that the inputs come from `ops.` and are not strings. If you want to generate
                    # intermediary strings, wrap them in CSE variables with properly initialised bounds.

                    # If there is no FX bound but we know how to compute one we do so
                    assert not kwargs

                    def arg_to_bound(x):
                        if isinstance(x, CSEVariable):
                            return x.bounds
                        elif isinstance(x, sympy.Expr):
                            return bound_sympy(x)
                        else:
                            return x

                    arg_bounds = list(map(arg_to_bound, args))
                    return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
                else:
                    return ValueRanges.unknown()

            @staticmethod
            def indirect_indexing(
                var: CSEVariable,
                size: Union[sympy.Expr, int],
                check: bool = True,
                wrap_neg=True,
            ):
                if isinstance(size, int):
                    size = sympy.Integer(size)
                assert isinstance(size, sympy.Expr), size
                # Skip CSE since this doesn't return an expression

                if var.bounds.lower < 0:  # type: ignore[operator]
                    if wrap_neg:
                        stm = ops.add(var, ops.index_expr(size, torch.long))
                        # Mixed negative and non-negative
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            lt = ops.lt(var, 0)
                            stm = ops.where(lt, stm, var)
                    else:
                        stm = var

                    # Propagate bounds as we know how to compute them properly
                    new_bounds = ValueRanges.unknown()
                    if var.bounds != ValueRanges.unknown() and isinstance(
                        size, sympy.Number
                    ):
                        # Take the negative part of the bound and add size to it
                        # Then take union of that and the positive part
                        # This is a tighter bound than that of a generic ops.where, as we have info on the cond
                        neg_bounds = var.bounds & ValueRanges(-int_oo, -1)
                        new_bounds = ValueRanges(
                            neg_bounds.lower + size, neg_bounds.upper + size
                        )
                        # We don't have a good way of representing the empty range
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            pos = var.bounds & ValueRanges(0, int_oo)
                            new_bounds = new_bounds | pos

                    var = self.cse.generate(self.compute, stm, bounds=new_bounds)

                sympy_var = parent_handler.indirect_indexing(var, size, check)
                if generate_assert(check):
                    assert_lower = not (var.bounds.lower >= 0)
                    # value ranges cannot prove x < s when x and s are symbols
                    assert_upper = not isinstance(size, sympy.Number) or not (
                        var.bounds.upper < size
                    )
                    self.check_bounds(sympy_var, size, assert_lower, assert_upper)
                return sympy_var

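            # Worked example for the wrap-around above: a var with bounds [-4, 3]
            # and size 10 gives neg part [-4, -1] + 10 = [6, 9], unioned with the
            # non-negative part [0, 3], i.e. the tight range [0, 9] instead of the
            # generic ops.where bound.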
            @staticmethod
            def check_bounds(
                expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
            ):
                return self.check_bounds(expr, size, lower, upper)

            @staticmethod
            def load(name: str, index: sympy.Expr) -> CSEVariable:
                if name in self.cse.invalidated_stores:
                    # A load from an invalidated store requires us to
                    # keep the actual buffer around
                    V.kernel.must_keep_buffers.add(name)
                if free_symbol_is_type(index, SymT.TMP):
                    return self.indirect_load(name, index)
                store_cache = self.cse.store_cache
                if name in store_cache:
                    return store_cache[name]
                out = self.load(name, index)
                # count a load that is neither in the store_cache nor in the
                # cse cache
                if out.use_count == 1:
                    self.num_load += 1
                return out

            @staticmethod
            def _update_store_cache(name: str, value: CSEVariable):
                self.cse.store_cache[name] = value
                if self.current_node and name in V.graph.name_to_buffer:
                    buf = self.current_node.get_output(name)
                    for other_name in buf.get_mutations():
                        self.cse.store_cache[other_name] = value

            @staticmethod
            def store(
                name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
            ) -> None:
                self.store_buffer_names.add(name)
                if mode is None:
                    CSEProxy._update_store_cache(name, value)
                if name not in V.graph.removed_buffers:
                    return self.store(name, index, value, mode=mode)
                else:
                    return None  # type: ignore[return-value]

            @staticmethod
            def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
                self.store_buffer_names.add(name)
                CSEProxy._update_store_cache(name, value)

                if name not in V.graph.removed_buffers:
                    return self.store_reduction(name, index, value)

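            # The store cache above lets a later load of `name` (or of any buffer
            # the current node mutates) reuse the just-stored CSE variable instead
            # of re-reading memory; store()/store_reduction() record the value
            # before emitting the actual store for exactly that reason.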
            @staticmethod
            def reduction(
                dtype: torch.dtype,
                src_dtype: torch.dtype,
                reduction_type: ReductionType,
                value: Union[CSEVariable, Tuple[CSEVariable, ...]],
            ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
                self.num_reduction += 1
                return self.reduction(dtype, src_dtype, reduction_type, value)

            @staticmethod
            def scan(
                dtypes: Tuple[torch.dtype, ...],
                combine_fn: Callable[
                    [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]],
                    Tuple[CSEVariable, ...],
                ],
                values: Tuple[CSEVariable, ...],
            ) -> Tuple[CSEVariable, ...]:
                return self.scan(dtypes, combine_fn, values)

            @staticmethod
            def sort(
                dtypes: Tuple[torch.dtype, ...],
                values: Tuple[CSEVariable, ...],
                stable: bool,
                descending: bool,
            ) -> Tuple[CSEVariable, ...]:
                return self.sort(dtypes, values, stable, descending)

            @staticmethod
            def bucketize(
                values: CSEVariable,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ) -> CSEVariable:
                """
                [Note: Inductor bucketize op]

                Given values (tensor) and offsets_name (reference to the name of a 1D
                tensor), calculate the bucket that each value belongs to.

                e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
                return = [ 0, 1, 1, 1, 1, 3, 3, 4].

                When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
                When right == True, bucket i refers to range [offsets[i], offsets[i+1]).

                Offsets must be non-decreasing or the result is undefined.
                """
                return self.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

        # Use mypy to check protocol implemented correctly
        def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
            return h

        super().__enter__()
        assert self.overrides
        parent_handler = self.overrides(V.get_ops_handler())
        self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self

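    # On entry the kernel stacks its backend-specific overrides over the current
    # ops handler (parent_handler) and installs CSEProxy in front, so every
    # subsequent ops.* call flows CSEProxy -> overrides -> base handler.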
    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Note that V.graph.scheduler can be None when codegening triton template
        kernels.
        """
        if V.graph.scheduler:
            V.graph.scheduler.remove_kernel_local_buffers()
        super().__exit__(exc_type, exc_val, exc_tb)

    def rename_indexing(self, index) -> sympy.Expr:
        # adds the necessary kernel args for index expressions
        # and renames variables in index expressions to kernel arg names
        if isinstance(index, (list, tuple)):
            return [self.rename_indexing(x) for x in index]  # type: ignore[return-value]
        index = V.graph.sizevars.simplify(index)
        sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
        replacements = {
            x: self.args.size(x)
            for x in sorted_symbols
            if symbol_is_type(
                x,
                (
                    SymT.UNBACKED_INT,
                    SymT.SIZE,
                    SymT.PRECOMPUTED_SIZE,
                ),
            )
        }
        return sympy_subs(index, replacements)

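    # For example (hypothetical): a graph size symbol s0 appearing in an index
    # expression is registered via self.args.size and rewritten to the kernel
    # argument name (typically ks0), so the generated kernel receives it as a
    # parameter.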
    def create_cse_var(self, *args, **kwargs):
        return CSEVariable(*args, **kwargs)


@dataclasses.dataclass
class OptimizationContext:
    key: ClassVar[str] = "opt_ctx"

    dtype: Optional[torch.dtype] = None
    ops_name: str = ""


@functools.lru_cache(None)
def jinja2_env():
    try:
        import jinja2

        return jinja2.Environment(
            undefined=jinja2.StrictUndefined,
        )
    except ImportError:
        return None


class KernelTemplate:
    """
    Base class for defining kernel templates.

    Children classes: TritonTemplate, CUDATemplate
    """

    @staticmethod
    def indent_except_first(source: str, num_indents: int, indents_spacing=4):
        lines = source.splitlines(True)
        if len(lines) > 1:
            lines[1:] = [
                (" " * indents_spacing * num_indents) + line for line in lines[1:]
            ]
        return "".join(lines)

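    # e.g. indent_except_first("a\nb\n", 2) keeps "a" unchanged and prefixes "b"
    # with 8 spaces; it is exposed to templates below as the Jinja filter
    # "indent_except_first".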
    @staticmethod
    def _template_from_string(source):
        env = jinja2_env()
        if env is not None:
            env.filters["indent_except_first"] = KernelTemplate.indent_except_first
            from jinja2 import TemplateSyntaxError

            class DetailedTemplateSyntaxError(TemplateSyntaxError):
                def __init__(self, original_error):
                    super().__init__(
                        original_error.message,
                        original_error.lineno,
                        original_error.name,
                        original_error.filename,
                    )
                    self.original_error = original_error

                def __str__(self):
                    error_info = f"Error in template at line {self.lineno}\n"
                    error_info += f"Error message: {self.message}\n"
                    if hasattr(self.original_error, "source"):
                        lines = self.original_error.source.split("\n")
                        error_info += "Context:\n"
                        start = max(0, self.lineno - 2)
                        end = min(len(lines), self.lineno + 2)
                        for i in range(start, end):
                            if i == self.lineno - 1:
                                error_info += f"{i+1}: --> {lines[i]}\n"
                                if hasattr(self.original_error, "column"):
                                    error_info += (
                                        "     "
                                        + " " * (self.original_error.column - 1)
                                        + "^\n"
                                    )
                            else:
                                error_info += f"{i+1}: {lines[i]}\n"
                    return error_info

            try:
                return env.from_string(source)
            except TemplateSyntaxError as e:
                raise DetailedTemplateSyntaxError(e) from e

        return None

    @staticmethod
    def _fake_get_dtype(fake_out):
        _get_dtype_real = V.graph.get_dtype

        def get_dtype(name):
            if name == fake_out.get_name():
                return fake_out.get_dtype()
            return _get_dtype_real(name)

        return get_dtype

    def __init__(self, name: str):
        self.name = name

    def maybe_append_choice(self, choices, **kwargs):
        """
        Maybe generates a new ChoiceCaller and appends it into existing choices.

        choices: A list of ChoiceCallers.
        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
        """

        try:
            choices.append(self.generate(**kwargs))
        except NotImplementedError:
            pass

    def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller":
        """
        Generates a ChoiceCaller instance from the given arguments.
        """

        raise NotImplementedError
.venv/Lib/site-packages/torch/_inductor/codegen/cpp.py
ADDED
The diff for this file is too large to render. See raw diff.

.venv/Lib/site-packages/torch/_inductor/codegen/cpp_gemm_template.py
ADDED
@@ -0,0 +1,1043 @@
# mypy: allow-untyped-defs
import contextlib
import logging
import math
from functools import lru_cache
from typing import Any, Callable, cast, List, Optional, Set, Union
from unittest.mock import patch

import torch
import torch.utils

from ..._dynamo.utils import counters
from .. import config, ir, lowering as L
from ..kernel.mm_common import mm_args
from ..select_algorithm import DataProcessorTemplateWrapper
from ..utils import cache_on_self, has_free_symbols, parallel_num_threads
from ..virtualized import ops, V
from .cpp import get_export_declaration
from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType
from .cpp_template import CppTemplate
from .cpp_template_kernel import CppTemplateKernel
from .cpp_utils import (
    create_epilogue_with_attr,
    DTYPE_TO_CPP,
    GemmBlocking,
    get_gemm_template_output_and_compute_dtype,
)


log = logging.getLogger(__name__)

GEMM_TEMPLATE = r"""
{{template.header().getvalue()}}

{{micro_gemm.codegen_define(kernel)}}

{%- if x_scale is not none %}
{%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %}
{%- else %}
{%- set kernel_args = {"X": X, "W": W, "inp": inp} %}
{%- endif %}

extern "C" {{export_declaration}}
{{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}}
{
    {{kernel.maybe_codegen_profile()}}
    constexpr int64_t num_threads = {{num_threads}};
    constexpr int64_t N = {{N}};
    constexpr int64_t K = {{K}};
    constexpr int64_t Mr = {{micro_gemm.register_blocking.block_m}};
    constexpr int64_t Nr = {{micro_gemm.register_blocking.block_n}};
    constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}};
    constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
    constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;

{%- if is_dynamic_M %}
    const int64_t M = {{kernel.size(GemmOut, 0)}};
    const int64_t Mr_blocks = (M + Mr - 1) / Mr;
    {%- if num_threads > 1 %}
    int64_t Mt_blocks, Nt_blocks, Kt_blocks;
    mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks);
    {%- else %}
    const auto Mt_blocks = Mr_blocks;
    const auto Nt_blocks = Nr_blocks;
    const auto Kt_blocks = Kr_blocks;
    {%- endif %}
    int64_t Mc_blocks, Nc_blocks, Kc_blocks;
    uint32_t L1_cache_size = {{L1_cache_size}};
    uint32_t L2_cache_size = {{L2_cache_size}};
    mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>(
        num_threads,
        M,
        N,
        K,
        Mr,
        Nr,
        Kr,
        Mt_blocks,
        Nt_blocks,
        Kt_blocks,
        Mc_blocks,
        Nc_blocks,
        Kc_blocks,
        L1_cache_size,
        L2_cache_size
    );
    const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
    const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
    const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
{%- else %}
    constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
    constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}};
    constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}};
    constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}};
    constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}};
    constexpr int64_t Nc_blocks = {{template.cache_blocking().block_n}};
    constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}};
    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
    constexpr int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
{%- endif %}

    // make sure all partitions are assigned
    {{kernel.assert_function}}(
        Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= Mr_blocks * Nr_blocks * Kr_blocks,
        "Not all partitions are assigned."
    );

{%- if maybe_k_slicing %}
    std::unique_ptr<std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[]> local_buf_ptrs;
    if (num_k_slices > 1) {
        local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_k_slices]);
    }
{%- endif %}

{%- if num_threads > 1 %}
    #pragma omp parallel num_threads({{num_threads}})
    {
        const int tid = omp_get_thread_num();
        int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end;
        mm_get_thread_blocks(
            tid, Mr_blocks, Nr_blocks, Kr_blocks, Mt_blocks, Nt_blocks, Kt_blocks,
            m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end);
        {%- if maybe_k_slicing %}
        const int64_t k_group_id = tid / num_k_slices;
        const int64_t k_slice_id = tid % num_k_slices;
        {%- endif %}
{%- else %}
    {
        const int tid = 0;
        const int64_t m_block_start = 0;
        const int64_t m_block_end = Mr_blocks;
        const int64_t n_block_start = 0;
        const int64_t n_block_end = Nr_blocks;
        const int64_t k_block_start = 0;
        const int64_t k_block_end = Kr_blocks;
{%- endif %}
        {{ micro_gemm.codegen_init(kernel) }}
{%- if use_local_acc %}
    {%- set acc_buf_name = "local_acc_buf" %}
        {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
{%- endif %}
        for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
            const int64_t m_start = mc * Mr;
            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
            const int64_t m_size = m_end - m_start;
            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                const int64_t n_start = nc * Nr;
                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
                const int64_t n_size = n_end - n_start;
                // NB: assume we pad N, nc_block_end won't exceed padded N here.
                const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
{%- if use_local_acc %}
    {%- set acc = kernel.local_buffers[acc_buf_name] %}
                {{ kernel.reinit_buffer_if_null(acc_buf_name) }}
{%- else %}
    {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %}
{%- endif %}
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    int64_t k_start = kc * Kr;
                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
    {%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %}
                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
    {%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %}
    {%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %}
    {%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
                        if (kc == k_block_start) {
                            {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=False)|indent(28, false) }}
                        } else {
                            {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=True)|indent(28, false) }}
                        }
                    }
                }
{%- if maybe_k_slicing %}
                if (num_k_slices > 1) {
                    const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
                    local_buf_ptrs[mxn_cache_block_id * num_k_slices + k_slice_id].reset({{ kernel.release_buffer(acc_buf_name) }});
                } else
{%- endif %}
                {
{%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %}
{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %}
                    {{ kernel.store_output(
                          tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
                       )|indent(20, false)
                    }}
                }
            }
        }
{%- if maybe_k_slicing %}
        if (num_k_slices > 1) {
            #pragma omp barrier
            for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
                // We slice M-dim and each thread in the k-slicing group works on a slice
                const int64_t m_start_unsliced = mc * Mr;
                const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
                const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced;
                const int64_t m_slice_size = (m_size_unsliced + num_k_slices - 1) / num_k_slices;
                const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced);
                const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced);
                const int64_t m_size = m_end - m_start;
                const int64_t m_offset = m_start - m_start_unsliced;
                for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                    const int64_t n_start = nc * Nr;
                    const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
                    const int64_t n_size = n_end - n_start;
                    const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
                    auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get();
                    for (int64_t other_slice = 1; other_slice < num_k_slices; other_slice++) {
                        auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_k_slices + other_slice].get();
                        for (int64_t m = m_offset; m < m_offset + m_size; m++) {
                            #pragma omp simd
                            for (int64_t n = 0; n < n_size; n++) {
                                {{acc_buf_name}}[m*Nr + n] += other_acc[m*Nr + n];
                            }
                        }
                    }
    {%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %}
                    {{ kernel.store_output(
                          tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
                       )|indent(20, false)
                    }}
                }
            }
        }
{%- endif %}
        {{ micro_gemm.codegen_finalize(kernel) }}
    }
}
"""

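# The template above implements a three-level blocking scheme: threads partition
# the grid of (Mr x Nr x Kr) register blocks (Mt/Nt/Kt), each thread walks its
# partition in cache-sized chunks (Mc/Nc/Kc), and the innermost loops invoke the
# micro-kernel on single register blocks, accumulating over K. When
# maybe_k_slicing is enabled, threads sharing an (m, n) block reduce their
# partial accumulators after the OMP barrier.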

def get_padded_n(n, block_n):
    return (n + block_n - 1) // block_n * block_n

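# e.g. get_padded_n(1000, 64) == 1024: N is rounded up to a whole number of
# register blocks so the packed weight always has full-width N panels.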

class CppPackedGemmTemplate(CppTemplate):
    def __init__(
        self,
        input_nodes,
        layout: ir.Layout,
        num_threads: int,
        register_blocking: GemmBlocking,
        beta=1,
        alpha=1,
        has_bias=False,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
    ) -> None:
        assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8]
        super().__init__(
            "packed_gemm",
            input_nodes,
            layout,
            num_threads,
            epilogue_creator=epilogue_creator,
        )
        self.beta = beta
        self.alpha = alpha
        self.has_bias = has_bias
        self.register_blocking = register_blocking
        m, n = layout.size
        _, k = input_nodes[0].get_size()
        self.m, self.n, self.k = m, n, k
        self.padded_n = get_padded_n(n, self.register_blocking.block_n)
        self.is_dynamic_M = has_free_symbols((m,))

    @cache_on_self
    def thread_blocking(self) -> GemmBlocking:
        """
        NOTE [Thread blocking in Cpp GEMM]
        We use simple heuristics to decide the thread blocking:
        1. Make sure all threads are occupied as much as possible.
        2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse.
        3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing.
        TODO(jgong5): allow tuning various blocking options
        """

        @lru_cache(maxsize=100)
        def get_factors(number):
            factors = []
            for i in range(int(number**0.5), 0, -1):
                if number % i == 0:
                    factors.append(number // i)
                    factors.append(i)
            return factors

        def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks):
            thread_block_k = math.ceil(k_blocks / k_factor)
            thread_block_n = math.ceil(n_blocks / n_factor)
            thread_block_m = math.ceil(m_blocks / m_factor)
            return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)

        assert (
            not self.is_dynamic_M
        ), "Unable to determine thread blocking for dynamic M."
        register_blocking = self.register_blocking
        m_blocks = math.ceil(self.m / register_blocking.block_m)
        n_blocks = math.ceil(self.n / register_blocking.block_n)
        k_blocks = math.ceil(self.k / register_blocking.block_k)
        factors = get_factors(self.num_threads)
        assert len(factors) > 0

        if config.cpp.gemm_thread_factors is not None:
            factors = [int(i) for i in config.cpp.gemm_thread_factors.split(",")]
            assert len(factors) == 3
            assert math.prod(factors) == self.num_threads
            return get_blocking(
                factors[0], factors[1], factors[2], m_blocks, n_blocks, k_blocks
            )

        # we favor square-sized thread blocks for good data reuse
        def get_better_blocking(blocking, best_blocking):
            if best_blocking is None:
                best_blocking = blocking
            else:
                block_m_size = blocking.block_m * register_blocking.block_m
                block_n_size = blocking.block_n * register_blocking.block_n
                best_block_m_size = best_blocking.block_m * register_blocking.block_m
                best_block_n_size = best_blocking.block_n * register_blocking.block_n
                if blocking.block_k > best_blocking.block_k:
                    best_blocking = blocking
                elif (
                    blocking.block_k == best_blocking.block_k
                    and block_m_size + block_n_size
                    < best_block_m_size + best_block_n_size
                ):
                    best_blocking = blocking
            return best_blocking

        best_blocking = None
        # check if we can have a thread-blocking to occupy all threads without k-slicing
        for n_factor in factors:
            m_factor = self.num_threads // n_factor
            if n_blocks >= n_factor and m_blocks >= m_factor:
                blocking = get_blocking(
                    m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
                )
                best_blocking = get_better_blocking(blocking, best_blocking)

        if best_blocking is None:
            for k_factor in factors:
                if k_blocks >= k_factor and (
                    config.cpp.gemm_max_k_slices == 0
                    or k_factor <= config.cpp.gemm_max_k_slices
                ):
                    n_factors = get_factors(self.num_threads // k_factor)
                    for n_factor in n_factors:
                        m_factor = (self.num_threads // k_factor) // n_factor
                        if n_blocks >= n_factor and m_blocks >= m_factor:
                            blocking = get_blocking(
                                m_factor,
                                n_factor,
                                k_factor,
                                m_blocks,
                                n_blocks,
                                k_blocks,
                            )
                            best_blocking = get_better_blocking(blocking, best_blocking)

        if best_blocking is None:
            for n_factor in factors:
                m_factor = self.num_threads // n_factor
                if n_blocks >= n_factor or m_blocks >= m_factor:
                    blocking = get_blocking(
                        m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
                    )
                    best_blocking = get_better_blocking(blocking, best_blocking)

        assert best_blocking is not None
        return best_blocking

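    # Example: with num_threads=8, get_factors(8) yields [4, 2, 8, 1], so the
    # candidate (m_factor, n_factor) splits tried first are (2, 4), (4, 2), (1, 8)
    # and (8, 1); get_better_blocking then prefers a larger k block and, on ties,
    # the split whose per-thread (m, n) tile sizes have the smallest sum, i.e. the
    # most square-like tile.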
    @cache_on_self
    def cache_blocking(self) -> GemmBlocking:
        def get_cache_blocking(register_blocking, thread_blocking):
            Mr = register_blocking.block_m
            Nr = register_blocking.block_n
            Kr = register_blocking.block_k

            Mt_blocks = thread_blocking.block_m
            Nt_blocks = thread_blocking.block_n
            Kt_blocks = thread_blocking.block_k

            if config.cpp.gemm_cache_blocking is not None:
                blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")]
                assert len(blockings) == 3
                Mc_blocks, Nc_blocks, Kc_blocks = blockings
                return (
                    min(Mc_blocks, Mt_blocks),
                    min(Nc_blocks, Nt_blocks),
                    min(Kc_blocks, Kt_blocks),
                )

            # The ratios below are empirically determined to decide
            # the effective sizes of L1 and L2.
            # TODO: tune the factor here
            L1_limit_factor = 0.8
            L2_limit_factor = 0.5

            L1_cache_size = (
                torch._C._cpu._L1d_cache_size()
            )  # per core cache size in Bytes
            assert (
                L1_cache_size > 0
            ), f"Expect L1_cache_size > 0 but got {L1_cache_size}"
            L1 = L1_cache_size * L1_limit_factor

            L2_cache_size = (
                torch._C._cpu._L2_cache_size()
            )  # per core cache size in Bytes
            assert (
                L2_cache_size > 0
            ), f"Expect L2_cache_size > 0 but got {L2_cache_size}"
            L2 = L2_cache_size * L2_limit_factor

            def get_num_byte(dtype):
                return torch.tensor([], dtype=dtype).element_size()

            num_byte_A = get_num_byte(self.input_nodes[0].get_dtype())
            num_byte_B = get_num_byte(self.input_nodes[1].get_dtype())

            # NOTE [CPP GEMM Cache Blocking Algorithm]
            # Our overall strategy is to
            # 1) Make cache blocks of B L1-reside and reused by multiple rows of A, i.e. Mc.
            #    Here, B is Kc x Nr where Nr is a single register block. We use L1 size to
            #    decide Kc. We want to make Mc large enough to better reuse B.
            # 2) Make cache blocks of A L2-reside, which would limit Mc. We want to reuse A
            #    along N, where we have two sub-strategies (see notes below) to decide Mc and Nc.

            # Step 1: Decide Kc assuming B block is L1-reside.
            size_cache_B = Kr * Kt_blocks * Nr * num_byte_B
            Kc_blocks = Kt_blocks
            if size_cache_B > L1:
                Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B))

            # Step 2: Decide Mc assuming A block is L2-reside.
            min_Mc_ratio = 2  # TODO(jgong5): something to tune?
            min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr)
            assert min_Mc_blocks >= 1
            Kt_bytes = Kt_blocks * Kr * num_byte_A
            if min_Mc_blocks * Mr * Kt_bytes < L2:
                # Strategy 1: A (Mc x Kt) resides in L2 and reused by all Nt
                # when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks)
                # to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside
                # in L1.
                Mc_blocks = min(Mt_blocks, math.floor(L2 / (Mr * Kt_bytes)))
                Nc_blocks = 1
            else:
                # Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, we reuse
                # A (Mc x Kc) in L2 by B (Kc x Nc). C (Mc x Nc) resides in L2.
                Mc_blocks = Mt_blocks
                Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)
                Nc_bytes = Nc_blocks * Nr * 4  # assume C or acc is float32/int32
                Kc_bytes = Kc_blocks * Kr * num_byte_A
                if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2:
                    # The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2,
                    # assuming Mc == Nc for good data reuse.
                    M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8
                    if M_max < Mc_blocks * Mr:
                        Mc_blocks = math.floor(M_max / Mr)
                        Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)

            return Mc_blocks, Nc_blocks, Kc_blocks

        assert (
            not self.is_dynamic_M
        ), "Unable to determine cache blocking for dynamic M."
        register_blocking = self.register_blocking
        thread_blocking = self.thread_blocking()

        return GemmBlocking(*get_cache_blocking(register_blocking, thread_blocking))

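    # Worked example for Step 1 (hypothetical numbers): with a 32 KiB L1 and the
    # 0.8 limit factor (L1 = 26214.4 bytes), Nr = 32, Kr = 1 and bf16 B
    # (2 bytes/elem), Kc_blocks = floor(26214.4 / (1 * 32 * 2)) = 409 K-blocks of
    # B fit per Nr panel in L1.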
    def log_blockings(self):
        log.debug(f"Register blocking: {self.register_blocking}")  # noqa: G004
        if self.is_dynamic_M:
            # thread and cache blockings are determined at runtime for dynamic shapes
            return
        log.debug(f"Cache blocking: {self.cache_blocking()}")  # noqa: G004
        thread_blocking = self.thread_blocking()
        log.debug(f"Thread blocking: {thread_blocking}")  # noqa: G004

        def get_occupancy():
            m_blocks = math.ceil(self.m / self.register_blocking.block_m)
            n_blocks = math.ceil(self.n / self.register_blocking.block_n)
            k_blocks = math.ceil(self.k / self.register_blocking.block_k)
            m = math.ceil(m_blocks / thread_blocking.block_m)
            n = math.ceil(n_blocks / thread_blocking.block_n)
            k = math.ceil(k_blocks / thread_blocking.block_k)
            return (m, n, k)

        log.debug(
            f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}"  # noqa: G004
        )

    def maybe_k_slicing(self):
        if self.num_threads == 1:
            return False
        if self.is_dynamic_M:
            # TODO(jgong5): perhaps use size hint to decide?
            return True
        register_blocking = self.register_blocking
        k_blocks = math.ceil(self.k / register_blocking.block_k)
        thread_blocking = self.thread_blocking()
        return k_blocks > thread_blocking.block_k

    @staticmethod
    def add_choices(
        choices,
        layout,
        input_nodes,
        beta=1,
        alpha=1,
        has_bias=False,
        trans_w=False,
        input_indices=None,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
    ):
        if input_indices is None:
            input_indices = list(range(len(input_nodes)))

        def reorder_and_filter(inputs, layout_or_out):
            if has_bias:
                assert len(input_indices) >= 3
                # Assume the input order is [inp, x, w] and we reorder it to [x, w, inp]
                inp_idx = input_indices[0]
                x_idx = input_indices[1]
                w_idx = input_indices[2]
                return [
                    inputs[x_idx],
                    inputs[w_idx],
                    inputs[inp_idx],
                    *[inputs[idx] for idx in input_indices[3:]],
                ], layout_or_out
            else:
                assert len(input_indices) >= 2
                return [inputs[idx] for idx in input_indices], layout_or_out

        def maybe_to_dense(inputs, layout_or_out):
            new_inputs = list(inputs)
            if isinstance(inputs[1], torch.Tensor):
                W = inputs[1]
                new_inputs[1] = W.to_dense() if W.is_mkldnn else W
            return new_inputs, layout_or_out

        def normalize_shapes(inputs, layout_or_out):
            if not trans_w:
                return inputs, layout_or_out
            new_inputs = list(inputs)
            X = inputs[0]
            W = inputs[1]
            B = inputs[2] if has_bias else None
            if isinstance(W, ir.IRNode):
                if trans_w:
                    if not isinstance(W, ir.TensorBox):
                        W = ir.TensorBox(W)
                    W = L.permute(W, [1, 0])
            else:
                if trans_w:
                    assert isinstance(W, torch.Tensor)
                    W = W.transpose(0, 1)
            if B is not None:
                if isinstance(B, ir.IRNode):
                    if not isinstance(B, ir.TensorBox):
                        B = ir.TensorBox(B)
                    B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
                else:
                    assert isinstance(B, torch.Tensor)
                    B = B.expand(X.shape[0], B.shape[-1])
            new_inputs[1] = W
            if B is not None:
                new_inputs[2] = B
            return new_inputs, layout_or_out

        # TODO(jgong5): decide proper number of threads per problem size
        num_threads = parallel_num_threads()
        new_inputs, _ = normalize_shapes(
            *maybe_to_dense(*reorder_and_filter(input_nodes, layout))
        )
        m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1])
        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
            new_inputs[0].get_dtype()
        )
        micro_gemm = create_micro_gemm(
            "micro_gemm",
            m,
            n,
            k,
            input_dtype=new_inputs[0].get_dtype(),
            input2_dtype=new_inputs[1].get_dtype(),
            output_dtype=output_dtype,
            compute_dtype=compute_dtype,
            alpha=alpha,
            num_threads=num_threads,
        )
        assert micro_gemm is not None
        _, block_n, _ = micro_gemm.register_blocking
        padded_n = get_padded_n(n, block_n)

        def pack_weight(inputs, layout_or_out):
            W = inputs[1]
            new_inputs = list(inputs)
            blocked_w: Union[ir.IRNode, torch.Tensor] = W
            if isinstance(W, ir.IRNode):
                new_size = [padded_n // block_n, k, block_n]
                blocked_w = ir.Buffer(
                    W.get_name(),  # Borrow the registered buffer name
                    ir.FixedLayout(
                        W.get_device(),
                        W.get_dtype(),
                        new_size,
                        ir.FlexibleLayout.contiguous_strides(new_size),
                        0,
                    ),
                )
            else:
                blocked_w = (
                    torch.nn.functional.pad(W, (0, padded_n - n))
                    .reshape(k, padded_n // block_n, block_n)
                    .transpose(0, 1)
                    .contiguous()
                )
                if micro_gemm.get_b_layout() != LayoutType.NORMAL:
                    layout_str = (
                        "VNNI4"
                        if micro_gemm.get_b_layout() == LayoutType.VNNI4
                        else "VNNI2"
                    )
                    assert micro_gemm.get_b_layout() in [
                        LayoutType.VNNI2,
                        LayoutType.VNNI4,
                    ], f"We only support {layout_str} for now"
                    vnni_size = (
                        4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
                    )
                    assert (
                        k % vnni_size == 0
                    ), f"k should be divisible by vnni_size for {layout_str} layout"
                    blocked_w = (
                        blocked_w.view(
                            padded_n // block_n, k // vnni_size, vnni_size, block_n
                        )
                        .transpose(-1, -2)
                        .contiguous()
                        .view(padded_n // block_n, k, block_n)
                    )
                # normalize stride to be "contiguous_strides" per size
                # this avoids the problems in L.view during template codegen
                new_stride = [1]
                for sz in reversed(blocked_w.shape[1:]):
                    new_stride.insert(0, new_stride[0] * sz)
                blocked_w = blocked_w.as_strided(blocked_w.shape, new_stride)
            new_inputs[1] = blocked_w

            def _is_int8_gemm(inputs):
                return (
                    isinstance(inputs[0], ir.IRNode)
                    and inputs[0].get_dtype() == torch.uint8
                ) or (
                    isinstance(inputs[0], torch.Tensor)
                    and inputs[0].dtype == torch.uint8
                )

            if _is_int8_gemm(new_inputs):
                BCompensate = None
                if isinstance(W, ir.IRNode):
                    BCompensate = V.graph.add_tensor_constant(
                        V.graph.constants[W.get_name() + "_BMatrixCompens"],
                        W.get_name() + "_BMatrixCompens",
                    )
                else:
                    BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0)  # type: ignore[assignment]
                new_inputs.append(BCompensate)
            return new_inputs, layout_or_out

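        # VNNI packing note: for the VNNI2/VNNI4 layouts, pack_weight stores B as
        # [n_block][k // vnni_size][block_n][vnni_size] (flattened back to
        # [n_block][k][block_n]), so the vnni_size consecutive K values of each
        # output column sit contiguously in memory, which is the interleaving that
        # AMX/VNNI-style instructions consume.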
        def preprocessor(inputs, layout):
            return pack_weight(
                *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout)))
            )

        def postprocessor(output):
            if isinstance(output, ir.TensorBox):
                # prepack the weight as input to the template buffer
                template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
                assert isinstance(template_buffer, ir.CppTemplateBuffer)
                new_input_nodes, _ = reorder_and_filter(input_nodes, layout)

                W_node = new_input_nodes[1]
                assert W_node.get_name() in V.graph.constants
                W = V.graph.constants[W_node.get_name()]
                new_input_nodes[1] = W
                new_input_nodes, _ = pack_weight(
                    *normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
                )

                # By using the new packed weight for the GEMM template, we can prune the
                # old weight if it has no other users. This saves memory but makes the FX graph
                # non-retraceable. To support retracing, we can add a repack node to the
                # FX graph. For example:
                # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template
                W_tensor_users = 0
                for node in reversed(V.graph.graph.nodes):
                    # Case may happen when the wgt tensor is used by more than 1 get_attr node
                    # https://github.com/pytorch/pytorch/issues/134998
                    if node.op == "get_attr" and hasattr(
                        V.graph.module, node.name
                    ):  # wgt might already be deleted
                        comp_tensor = getattr(V.graph.module, node.name)
                        if (
                            W.is_mkldnn == comp_tensor.is_mkldnn
                            and W.dtype == comp_tensor.dtype
                            and W.device == comp_tensor.device
                            and (
                                (
                                    not W.is_mkldnn
                                    and (
                                        W.untyped_storage().data_ptr()
                                        == comp_tensor.untyped_storage().data_ptr()
                                    )
                                )
                                or (
                                    W.is_mkldnn
                                    and (
                                        torch.ops.mkldnn.data_ptr(W)
                                        == torch.ops.mkldnn.data_ptr(comp_tensor)
                                    )
                                )
                            )
                        ):
                            W_tensor_users += 1

                for node in reversed(V.graph.graph.nodes):
                    # The wgt tensor has been used by only 1 get_attr node
                    # The get_attr node has only 1 user fx node
                    if (
                        node.name == W_node.get_name()
                        and len(node.users) == 1
                        and W_tensor_users == 1
                    ):
                        del V.graph.constants[node.name]
                        delattr(V.graph.module, node.name)
                        delattr(V.graph.graph.owning_module, node.name)

                W_packed = new_input_nodes[1]
                W_packed_constant = V.graph.add_tensor_constant(W_packed)
                template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input(
                    W_packed_constant
                )
            return output

        template = DataProcessorTemplateWrapper(
            CppPackedGemmTemplate,
            preprocessor,
            postprocessor,
            input_nodes=input_nodes,
            layout=layout,
            num_threads=num_threads,
            register_blocking=micro_gemm.register_blocking,
            beta=beta,
            alpha=alpha,
            has_bias=has_bias,
            epilogue_creator=epilogue_creator,
        )
        template.maybe_append_choice(choices)
        return template

    def render(  # type: ignore[override,return]
        self,
        kernel: CppTemplateKernel,
        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
        flag_template_buffer_has_other_users: Optional[bool] = None,
        epilogue_nodes: Optional[List[ir.IRNode]] = None,
        **kwargs,
    ) -> str:
        assert len(self.input_nodes) >= 2

        int8_gemm = self.input_nodes[0].get_dtype() == torch.uint8
        x_scale = None
        x_zp = None
        w_scale = None
        w_zp = None
        if int8_gemm:
            X, W = self.input_nodes[0], self.input_nodes[1]
            bias_idx = 2 if self.has_bias else 1
            inp = self.input_nodes[bias_idx] if self.has_bias else None
            x_scale = self.input_nodes[bias_idx + 1]
            x_zp = self.input_nodes[bias_idx + 2]
            w_scale = self.input_nodes[bias_idx + 3]
            w_zp = self.input_nodes[bias_idx + 4]
            Y = self.output_node
        else:
            X, W = self.input_nodes[0], self.input_nodes[1]
            Y = self.output_node
            inp = self.input_nodes[2] if self.has_bias else None

        template_buffer_has_other_users = None

        if template_buffer_node is not None:
            # Use the updated prepacked weight buffer
            W = template_buffer_node.inputs[1]
            Y = template_buffer_node

            assert flag_template_buffer_has_other_users is not None
            template_buffer_has_other_users = flag_template_buffer_has_other_users

        template_buffer = Y
        gemm_output_buffer = template_buffer

        epilogues: List[ir.IRNode] = []
        reindexers: List[Optional[Callable[[List[Any]], List[Any]]]] = []
        epilogue_creators: List[Callable[[ir.Buffer], ir.Pointwise]] = []
        fake_buffers: List[ir.Buffer] = []
        Y_aliases: Set[str] = set()

        use_local_acc = (
            self.layout.dtype != torch.float
            or template_buffer_has_other_users
            or int8_gemm
            or self.padded_n != self.n
            or self.maybe_k_slicing()
        )

        # TODO(jgong5): for int8 gemm, bias-add is handled outside of gemm template,
        # but we'd better move it here to align with fp.
        if inp is not None and self.beta != 0 and not int8_gemm:
            # add an epilogue for bias add
            def _bias_add_epilogue(buf):
                return create_epilogue_with_attr(
                    buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype
                )

            epilogue_creators.append(_bias_add_epilogue)

        if self.epilogue_creator is not None:
            epilogue_creators.append(self.epilogue_creator)

        # When the GEMM output buffer is localized but it has users other than the epilogue nodes,
        # we need to copy the value in the GEMM output local buffer to a global buffer.
        def need_copy_from_local_to_global_buffer_epilogue(
            use_local_acc, template_buffer_has_other_users, epilogue_creators
        ):
            # The GEMM output buffer is a global buffer, thus copy is not needed.
            if not use_local_acc:
                return False

            # The possible value of template_buffer_has_other_users is (None, False, True)
            # It is None when generating the gemm template during autotune and it will have value during scheduler codegen.
            # extra copy_from_local_to_global_buffer_epilogue is not needed in either of the below two cases:
            # 1. template_buffer_has_other_users is None (i.e. when doing the codegen during autotune)
            # 2. template_buffer_has_other_users is False, which means it's safe to keep the value in the
            #    GEMM output buffer in local buffer only (no users outside of the epilogues will use its value).
            if not template_buffer_has_other_users:
                return False

            # When bias is not None or self.epilogue_creator is not None,
            # there will be epilogue_creators after the GEMM.
            # The GEMM output buffer is localized while
            # the output buffer of the epilogue_creators is a global buffer.
            if epilogue_creators:
                return False

            return True

        if need_copy_from_local_to_global_buffer_epilogue(
            use_local_acc, template_buffer_has_other_users, epilogue_creators
        ):

            def copy_from_local_to_global_buffer_epilogue(input_buffer: ir.Buffer):
                dtype = self.layout.dtype
                input_loader = input_buffer.make_loader()

                def copy_inner(index):
                    input = input_loader(index)
                    result = ops.to_dtype(input, dtype)
                    return result

                return ir.Pointwise(
                    device=input_buffer.get_device(),
                    dtype=self.layout.dtype,
                    inner_fn=copy_inner,
                    ranges=input_buffer.get_size(),
                )

            epilogue_creators.append(copy_from_local_to_global_buffer_epilogue)

        # NOTE [How CPP GEMM template epilogues are organized]
        #   gemm_output_buffer
        #     --> zero or more in-template epilogues (created by `epilogue_creators`) -->
        #   template_buffer
        #     --> zero or more out-of-template epilogues (`epilogue_nodes`) -->
        #   Y
        if epilogue_creators:
            gemm_output_name = "buf_GemmOut"
            gemm_output_buffer = ir.Buffer(gemm_output_name, template_buffer.layout)
            current_input_buffer = gemm_output_buffer
            for i, creator in enumerate(epilogue_creators):
                if i == len(epilogue_creators) - 1:
                    buffer_name = template_buffer.get_name()
                else:
                    buffer_name = f"buf_GemmOut_epilogue_{i}"
                epilogues.append(
                    ir.ComputedBuffer(
                        name=buffer_name,
                        layout=template_buffer.layout,
                        data=creator(current_input_buffer),
                    )
                )
                fake_buffers.append(current_input_buffer)
                Y_aliases.add(current_input_buffer.get_name())
                reindexers.append(None)
                if i < len(epilogue_creators) - 1:
                    current_input_buffer = ir.Buffer(
                        buffer_name, template_buffer.layout
                    )

        Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y

        if epilogue_nodes:
            epilogues.extend(epilogue_nodes)
            assert Y.get_numel() == epilogues[-1].get_numel()
            Y = cast(ir.Buffer, epilogues[-1])

            if not template_buffer_has_other_users:
                Y_aliases.add(template_buffer.get_name())

            if (
                Y.get_size() == template_buffer.get_size()
                and Y.get_stride() == template_buffer.get_stride()
            ):
                reindexers.extend([None] * len(epilogue_nodes))
                Y_2d = Y
            else:

                def get_reindexer(epilogue_node):
                    # From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example:
                    #   template_buffer:
                    #     size (324, 512), stride (512, 1)
                    #   epilogue_node_ordered (ordered by stride decreasingly, in dense format):
                    #     size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
                    stride_order = list(
                        ir.get_stride_order(
                            V.graph.sizevars.size_hints(epilogue_node.get_stride())
                        )
                    )
                    fill_order = ir.stride_order2fill_order(stride_order)
                    reversed_fill_order = list(reversed(fill_order))
                    size_with_stride_ordered_decreasingly = [
                        epilogue_node.get_size()[i] for i in reversed_fill_order
                    ]
                    reshape_reindex = ir.View.dynamic_reshape_indexer(
                        size_with_stride_ordered_decreasingly,
                        template_buffer.get_size(),
                    )

                    # From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example:
                    #   epilogue_node_ordered (ordered by stride decreasingly, in dense format):
                    #     size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
                    #   epilogue_node:
                    #     size (1, 18, 18, 512), stride (165888, 1, 9216, 512)
                    from_stride_ordered_decreasingly_to_epilogue_node_order = [
                        (len(stride_order) - 1) - stride_order[i]
                        for i in range(len(stride_order))
                    ]
                    stride_reindex = ir.same_reorder(
                        from_stride_ordered_decreasingly_to_epilogue_node_order
                    )

                    reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex)
                    return reindexer

                reindexers.extend([get_reindexer(epilogue_node) for epilogue_node in epilogue_nodes])  # type: ignore[list-item]
                if isinstance(Y, ir.BaseView):
                    storage = ir.StorageBox(Y.unwrap_view())
                else:
                    assert isinstance(Y, ir.Buffer)
                    storage = ir.StorageBox(Y)
                Y_2d = ir.ReinterpretView(storage, template_buffer.get_layout())

        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
            X.get_dtype()
        )
        micro_gemm = create_micro_gemm(
            f"{kernel.kernel_name}_micro_gemm",
            self.m,
            self.n,
            self.k,
            input_dtype=X.get_dtype(),
            input2_dtype=W.get_dtype(),
            output_dtype=output_dtype,
            compute_dtype=compute_dtype,
            alpha=self.alpha,
            num_threads=self.num_threads,
        )
        assert micro_gemm is not None
        assert self.register_blocking == micro_gemm.register_blocking
        self.log_blockings()
        if isinstance(micro_gemm, CppMicroGemmAMX):
            counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1

        L1_cache_size = torch._C._cpu._L1d_cache_size()  # per core cache size in Bytes
        assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"

        L2_cache_size = torch._C._cpu._L2_cache_size()  # per core cache size in Bytes
        assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"

        options = dict(
            X=X,
            W=W,
            inp=inp,
            Y=Y,
            N=self.n,
            K=self.k,
            PADDED_N=self.padded_n,
            GemmOut=gemm_output_buffer,
            aliases={alias: Y.get_name() for alias in Y_aliases},
            beta=self.beta,
            alpha=self.alpha,
            num_threads=self.num_threads,
            micro_gemm=micro_gemm,
            is_dynamic_M=self.is_dynamic_M,
            template=self,
            kernel=kernel,
            export_declaration=get_export_declaration(),
            epilogue_nodes=epilogues,
            reindexers=reindexers,
            Y_2d=Y_2d,
            use_local_acc=use_local_acc,
            maybe_k_slicing=self.maybe_k_slicing(),
            x_scale=x_scale,
            x_zp=x_zp,
            w_scale=w_scale,
            w_zp=w_zp,
            acc_buf_dtype=torch.int32 if int8_gemm else torch.float,
            DTYPE_TO_CPP=DTYPE_TO_CPP,
            L1_cache_size=L1_cache_size,
            L2_cache_size=L2_cache_size,
            config=config,
        )
        with contextlib.ExitStack() as stack:
            for buf in fake_buffers:
                stack.enter_context(
                    patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
                )
            return self._template_from_string(GEMM_TEMPLATE).render(**options)
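The stride-order arithmetic in `get_reindexer` above is easy to lose track of. The sketch below re-implements the two helpers with assumed semantics (ranking each dim by its stride, smallest first, and taking the inverse permutation); these are illustrative stand-ins for `ir.get_stride_order` and `ir.stride_order2fill_order`, not the inductor functions themselves, run on a channels-last-like layout consistent with the numbers in the comments above.

# Standalone sketch of the stride-order steps in get_reindexer (assumed
# semantics; not the actual ir.get_stride_order / ir.stride_order2fill_order).

def get_stride_order(strides):
    # rank each dim by its stride: rank 0 = smallest stride (innermost dim)
    dims_by_stride = sorted(range(len(strides)), key=lambda d: strides[d])
    order = [0] * len(strides)
    for rank, dim in enumerate(dims_by_stride):
        order[dim] = rank
    return order

def stride_order2fill_order(order):
    # inverse permutation: fill_order[rank] = the dim holding that rank
    fill = [0] * len(order)
    for dim, rank in enumerate(order):
        fill[rank] = dim
    return fill

# channels-last-like node: NCHW sizes, strides increasing along H, W, C
size = [1, 512, 18, 18]
stride = [165888, 1, 9216, 512]
order = get_stride_order(stride)                  # [3, 0, 2, 1]
fill = stride_order2fill_order(order)             # [1, 3, 2, 0]
reversed_fill = list(reversed(fill))              # [0, 2, 3, 1]
ordered_size = [size[d] for d in reversed_fill]   # [1, 18, 18, 512] -> dense layout
perm = [(len(order) - 1) - order[i] for i in range(len(order))]  # [0, 3, 1, 2]
print(order, fill, ordered_size, perm)

Roughly, `ordered_size` is the dense shape (strides (165888, 9216, 512, 1)) that `dynamic_reshape_indexer` matches against the flat 2D template buffer, and `perm` is the dim reorder that `same_reorder` composes on top via `fuse_reindexing`.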
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py
ADDED
File without changes
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py
ADDED
@@ -0,0 +1,173 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_1_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_1_half_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_1_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
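Stripped of the expand/view bookkeeping, `_sfdp_pattern_1_inference` above encodes plain scaled-dot-product attention with the scale written as a division, and the amax/sub/exp/sum/div chain is a numerically stable softmax. A small eager-mode sketch (shapes and the scale value are illustrative, not taken from the pattern):

import torch

def sfdp_pattern_1(query, key, value, inv_scale):
    # bmm(q, k^T) then div by inv_scale, as in the pattern above
    scores = torch.matmul(query, key.transpose(-2, -1)) / inv_scale
    # the amax/sub/exp/sum/div chain is a numerically stable softmax
    e = (scores - scores.amax(dim=-1, keepdim=True)).exp()
    attn = e / e.sum(dim=-1, keepdim=True)
    return torch.matmul(attn, value)

q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)
out = sfdp_pattern_1(q, k, v, inv_scale=16 ** 0.5)
ref = torch.softmax(q @ k.transpose(-2, -1) / 16 ** 0.5, dim=-1) @ v
assert torch.allclose(out, ref, atol=1e-6)

The `_users=N` annotations in the pattern constrain how many consumers a matched intermediate must have, which is how the training variant distinguishes tensors that are also reused by the backward ops; the `_half` variants differ only in the `prims.convert_element_type` casts that autograd inserts around the fp32 softmax in reduced precision.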
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py
ADDED
@@ -0,0 +1,204 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_10_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_10_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
view_default_8 = CallFunction(aten.view.default, convert_element_type_default_3, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_10_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_10_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
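`_sfdp_pattern_10` differs from `_sfdp_pattern_1` mainly in where the tracing puts the scale: here the query is divided before the bmm (`div_Tensor` feeds `expand_default`) rather than the scores after it, and the contiguous `clone`s reflect permuted input layouts. Both placements compute the same attention scores, which is why each tracing shape needs its own serialized pattern. A quick check of that equivalence (shapes and the scale value are illustrative):

import torch

q = torch.randn(2, 8, 16)
k = torch.randn(2, 8, 16)
scale = 4.0
scores_after = torch.bmm(q, k.transpose(-2, -1)) / scale   # pattern_1 style: div after bmm
scores_before = torch.bmm(q / scale, k.transpose(-2, -1))  # pattern_10 style: div before bmm
assert torch.allclose(scores_after, scores_before, atol=1e-5)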
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py
ADDED
@@ -0,0 +1,203 @@
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
# noqa: F401, E501
|
| 4 |
+
# This is an auto-generated file. Please do not modify it by hand.
|
| 5 |
+
# To re-generate, run:
|
| 6 |
+
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch._inductor
|
| 10 |
+
|
| 11 |
+
aten = torch.ops.aten
|
| 12 |
+
prims = torch.ops.prims
|
| 13 |
+
|
| 14 |
+
from torch._inductor.pattern_matcher import (
|
| 15 |
+
Arg,
|
| 16 |
+
CallFunction,
|
| 17 |
+
CallFunctionVarArgs,
|
| 18 |
+
CallMethod,
|
| 19 |
+
CallMethodVarArgs,
|
| 20 |
+
CallModule,
|
| 21 |
+
CallModuleVarArgs,
|
| 22 |
+
ExclusiveKeywordArg,
|
| 23 |
+
Ignored,
|
| 24 |
+
KeywordArg,
|
| 25 |
+
ListOf,
|
| 26 |
+
MultiOutputPattern,
|
| 27 |
+
PatternExpr,
|
| 28 |
+
RepeatedExpr,
|
| 29 |
+
_TargetArgsExpr,
|
| 30 |
+
_TargetExpr,
|
| 31 |
+
_TargetExprVarArgs,
|
| 32 |
+
)
|
| 33 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 34 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 35 |
+
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
| 36 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
|
| 37 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 38 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 39 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 40 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 41 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
|
| 42 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 43 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 44 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
|
| 45 |
+
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
|
| 46 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
| 47 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 48 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 49 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
| 50 |
+
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
| 51 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 52 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 53 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 54 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 55 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
| 56 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 57 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 58 |
+
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
| 59 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 60 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 61 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
| 62 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 63 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
|
| 64 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
| 65 |
+
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
| 66 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
|
| 67 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 68 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 69 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
|
| 70 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 71 |
+
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
|
| 72 |
+
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 73 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
|
| 74 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 75 |
+
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 76 |
+
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
| 77 |
+
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 78 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
|
| 79 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 80 |
+
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
|
| 81 |
+
_sfdp_pattern_11_training = MultiOutputPattern([view_default_5,
|
| 82 |
+
permute_default_6,
|
| 83 |
+
permute_default_9,
|
| 84 |
+
permute_default_11,
|
| 85 |
+
None
|
| 86 |
+
])
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 90 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 91 |
+
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
| 92 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored())
|
| 93 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 94 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 95 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 96 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 97 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
|
| 98 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 99 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 100 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
|
| 101 |
+
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
|
| 102 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
|
| 103 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 104 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 105 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 106 |
+
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
| 107 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 108 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 109 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 110 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_11_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_11_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_11_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
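For orientation: the `_sfdp_pattern_11` graphs above all decompose one computation, scaled dot-product attention, with softmax written out as the amax/sub/exp/sum/div chain; the `convert_element_type` nodes in the "half" variants appear to be fp16/bf16 casts, and the second half of each training graph is the corresponding backward. A minimal eager-mode sketch of what the inference pattern is meant to match, assuming 4-D query/key/value and a scalar inv_scale (the helper name is illustrative, not part of the generated file):

import torch

def sdpa_pattern_11_eager(query, key, value, inv_scale):
    # (Q @ K^T) / inv_scale -> softmax over the last dim -> @ V
    scores = torch.matmul(query, key.transpose(-2, -1)) / inv_scale
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, value)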
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py
ADDED
@@ -0,0 +1,219 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_12_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_12_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
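`_sfdp_pattern_12` is the same scaled-attention computation with dropout folded in: `gt_Scalar` keeps positions where `rand > dropout_p` (the standard keep-mask), and the following `mul` with an `Ignored()` operand corresponds to the 1/(1 - p) rescale. A rough eager-mode sketch under those assumptions (the helper name is illustrative):

import torch
import torch.nn.functional as F

def sdpa_pattern_12_eager(query, key, value, inv_scale_factor, dropout_p):
    # scaled attention scores, softmax, then dropout on the probabilities
    scores = torch.matmul(query, key.transpose(-2, -1)) / inv_scale_factor
    attn = F.dropout(torch.softmax(scores, dim=-1), p=dropout_p)
    return torch.matmul(attn, value)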
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py
ADDED
@@ -0,0 +1,129 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2)
amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value'))
neg_default = CallFunction(aten.neg.default, div_Tensor)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1)
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4, _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, fma_default, permute_default_2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, fma_default)
permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored())
permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1'))
_sfdp_pattern_13_training = MultiOutputPattern([bmm_default_1,
  bmm_default_3,
  permute_default_4,
  bmm_default_5,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2)
amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, div_Tensor, KeywordArg('value'), _users=0)


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default)
convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value'))
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, convert_element_type_default_5, permute_default_2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, convert_element_type_default_5)
permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored())
permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1'))
_sfdp_pattern_13_half_training = MultiOutputPattern([bmm_default_1,
  bmm_default_3,
  permute_default_4,
  bmm_default_5,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default)
convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, convert_element_type_default_1, KeywordArg('value'), _users=0)
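`_sfdp_pattern_13` is noticeably smaller than its neighbors because it matches attention written directly as 3-D `bmm` calls, with dropout but no scale factor and no permute/expand/view reshaping of q/k/v. A rough eager-mode sketch under those assumptions (the helper name is illustrative):

import torch
import torch.nn.functional as F

def sdpa_pattern_13_eager(query, key, value, dropout_p):
    # (batch, seq, dim) inputs; note the absence of any 1/sqrt(d) scaling
    attn = torch.softmax(torch.bmm(query, key.transpose(1, 2)), dim=-1)
    return torch.bmm(F.dropout(attn, p=dropout_p), value)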
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py
ADDED
@@ -0,0 +1,209 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_14_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_14_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_14_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_14_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
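`_sfdp_pattern_14` adds an additive attention mask between the scaling and the softmax (the `add_Tensor` node with KeywordArg('attn_mask')) and has no dropout. A rough eager-mode sketch, assuming a mask broadcastable against the score matrix (the helper name is illustrative):

import torch

def sdpa_pattern_14_eager(query, key, value, inv_scale, attn_mask):
    # scaled scores plus additive mask, then softmax and the value matmul
    scores = torch.matmul(query, key.transpose(-2, -1)) / inv_scale
    attn = torch.softmax(scores + attn_mask, dim=-1)
    return torch.matmul(attn, value)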
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py
ADDED
@@ -0,0 +1,227 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default)
div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_15_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None
])


eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
expand_default = CallFunction(aten.expand.default, view_default, Ignored())
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 101 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 102 |
+
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 103 |
+
view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
|
| 104 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 105 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 106 |
+
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 107 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
|
| 108 |
+
view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
|
| 109 |
+
bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
|
| 110 |
+
view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 111 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
|
| 112 |
+
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
|
| 113 |
+
amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
|
| 114 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
|
| 115 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 116 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 117 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 118 |
+
expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
|
| 119 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 120 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 121 |
+
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 122 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
|
| 123 |
+
view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
| 124 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
|
| 125 |
+
_sfdp_pattern_15_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
|
| 129 |
+
expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
|
| 130 |
+
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 131 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 132 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 133 |
+
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 134 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
|
| 135 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 136 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 137 |
+
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 138 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
|
| 139 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
|
| 140 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 141 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 142 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
|
| 143 |
+
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
|
| 144 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
|
| 145 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 146 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 147 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 148 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 149 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 150 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
|
| 151 |
+
expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 152 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 153 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 154 |
+
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 155 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
|
| 156 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
|
| 157 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 158 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 159 |
+
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
|
| 160 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
| 161 |
+
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
| 162 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 163 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 164 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
| 165 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 166 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
|
| 167 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2)
|
| 168 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
|
| 169 |
+
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
|
| 170 |
+
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
|
| 171 |
+
where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_4)
|
| 172 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
|
| 173 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 174 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 175 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
|
| 176 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 177 |
+
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
|
| 178 |
+
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 179 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
|
| 180 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 181 |
+
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 182 |
+
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
| 183 |
+
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 184 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
|
| 185 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 186 |
+
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
|
| 187 |
+
_sfdp_pattern_15_half_training = MultiOutputPattern([view_default_5,
|
| 188 |
+
permute_default_6,
|
| 189 |
+
permute_default_9,
|
| 190 |
+
permute_default_11,
|
| 191 |
+
None,
|
| 192 |
+
None
|
| 193 |
+
])
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
|
| 197 |
+
view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
|
| 198 |
+
expand_default = CallFunction(aten.expand.default, view_default, Ignored())
|
| 199 |
+
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 200 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 201 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 202 |
+
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 203 |
+
view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
|
| 204 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 205 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 206 |
+
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 207 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
|
| 208 |
+
view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
|
| 209 |
+
bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
|
| 210 |
+
view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 211 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
|
| 212 |
+
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
|
| 213 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
|
| 214 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 215 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 216 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 217 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 218 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 219 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 220 |
+
expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 221 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 222 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 223 |
+
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 224 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
|
| 225 |
+
view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
| 226 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
|
| 227 |
+
_sfdp_pattern_15_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
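The generated files in this directory all compose the same small vocabulary from torch._inductor.pattern_matcher. As a minimal sketch (not taken from the upload; the toy pattern and its captured names are illustrative only), this is how those building blocks declare a pattern — here matching just the scaled bmm at the core of the graphs above:

import torch
from torch._inductor.pattern_matcher import CallFunction, Ignored, KeywordArg

aten = torch.ops.aten

# Hypothetical toy pattern: a batched matmul whose result is divided by a
# captured 'inv_scale' -- the innermost step of the SFDP graphs above.
# CallFunction names an ATen overload plus a sub-pattern for each argument.
toy_bmm = CallFunction(aten.bmm.default, KeywordArg('query'), KeywordArg('key'))
toy_scaled = CallFunction(aten.div.Tensor, toy_bmm, KeywordArg('inv_scale'))

KeywordArg captures an input by name so a replacement can reuse it, while Ignored() matches any argument, which is why shapes, dims, and dtypes appear as Ignored() throughout the serialized patterns.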
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py
ADDED
@@ -0,0 +1,598 @@
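_sfdp_pattern_16, rendered below, follows the same layout but applies the mask additively (aten.add.Tensor instead of the aten.where used by pattern 15) and includes dropout. A hedged sketch of just the dropout sub-structure, reusing only constructs that appear in the file body ('softmax_out' is a hypothetical placeholder, not a name from the file):

import torch
from torch._inductor.pattern_matcher import CallFunction, Ignored, KeywordArg

aten = torch.ops.aten

# Traced dropout appears as rand() compared against dropout_p, then a mul
# that zeroes the dropped positions of the softmax output.
rand_node = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(),
                         device=Ignored(), pin_memory=False)
keep_mask = CallFunction(aten.gt.Scalar, rand_node, KeywordArg('dropout_p'))
dropped = CallFunction(aten.mul.Tensor, keep_mask, KeywordArg('softmax_out'))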
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    CallFunctionVarArgs,
    CallMethod,
    CallMethodVarArgs,
    CallModule,
    CallModuleVarArgs,
    ExclusiveKeywordArg,
    Ignored,
    KeywordArg,
    ListOf,
    MultiOutputPattern,
    PatternExpr,
    RepeatedExpr,
    _TargetArgsExpr,
    _TargetExpr,
    _TargetExprVarArgs,
)

# fp32 training (forward + backward) graph with dropout traced as rand/gt.Scalar and an additive mask.
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_16_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])


# fp32 inference graph (no dropout nodes at inference).
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


# batch-size-1 fp32 training graph (variant without the contiguous-format clones).
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_16_bs1_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])


# batch-size-1 fp32 inference graph.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


# fp16/bf16 ("half") training graph with convert_element_type casts around the softmax.
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_16_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])


# fp16/bf16 ("half") inference graph.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


# batch-size-1 "half" training graph.
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_16_half_bs1_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])


# batch-size-1 "half" inference graph.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_16_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


# "half" compute with fp32 mask, training graph (softmax stays in fp32; only the dropout input is cast).
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_16_half_mask_fp32_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])


# "half" compute with fp32 mask, inference graph.
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 496 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 497 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 498 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 499 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
| 500 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 501 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 502 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 503 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 504 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
| 505 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 506 |
+
_sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 510 |
+
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
|
| 511 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 512 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 513 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
| 514 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 515 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 516 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 517 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
| 518 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 519 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 520 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
|
| 521 |
+
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
|
| 522 |
+
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
| 523 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
| 524 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 525 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 526 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
|
| 527 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
|
| 528 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
| 529 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
|
| 530 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
| 531 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 532 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 533 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 534 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 535 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 536 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 537 |
+
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
|
| 538 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 539 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 540 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
| 541 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 542 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
|
| 543 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 544 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
|
| 545 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2)
|
| 546 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
|
| 547 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
| 548 |
+
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
| 549 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
|
| 550 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale'))
|
| 551 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 552 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 553 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
|
| 554 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 555 |
+
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
|
| 556 |
+
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 557 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
|
| 558 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 559 |
+
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 560 |
+
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
| 561 |
+
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 562 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
|
| 563 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 564 |
+
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
|
| 565 |
+
_sfdp_pattern_16_half_mask_fp32_bs1_training = MultiOutputPattern([view_default_5,
|
| 566 |
+
permute_default_6,
|
| 567 |
+
permute_default_9,
|
| 568 |
+
permute_default_11,
|
| 569 |
+
None,
|
| 570 |
+
None,
|
| 571 |
+
None
|
| 572 |
+
])
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 576 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 577 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
| 578 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 579 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 580 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 581 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
| 582 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 583 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 584 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
|
| 585 |
+
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
|
| 586 |
+
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
|
| 587 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
|
| 588 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 589 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 590 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 591 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 592 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
|
| 593 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 594 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
|
| 595 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 596 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 597 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 598 |
+
_sfdp_pattern_16_half_mask_fp32_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
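For orientation: the `_sfdp_pattern_16_half_mask_fp32_*` expressions above serialize the aten-level decomposition of scaled-dot-product attention with an additive mask, scaling by division with `inv_scale`, and (in the training variants) dropout plus the corresponding backward graph. The sketch below is a minimal eager-mode illustration of the forward computation these patterns match; the function name, shapes, and the explicit dropout rescale are assumptions for illustration, not part of the generated file.

import torch

def sdpa_additive_mask(query, key, value, attn_mask, inv_scale, dropout_p=0.0):
    # bmm -> div.Tensor -> add.Tensor: scores scaled by dividing with
    # `inv_scale`, then shifted by an additive attention mask.
    scores = query @ key.transpose(-2, -1) / inv_scale + attn_mask
    # Numerically stable softmax, matching the amax -> sub -> exp ->
    # sum.dim_IntList -> div.Tensor chain in the serialized pattern.
    scores = scores - scores.amax(dim=-1, keepdim=True)
    probs = scores.exp()
    probs = probs / probs.sum(dim=-1, keepdim=True)
    # Training variants additionally match dropout written as
    # rand -> gt.Scalar -> mul; inference variants omit it.
    if dropout_p > 0.0:
        keep = (torch.rand_like(probs) > dropout_p).to(probs.dtype)
        probs = probs * keep / (1.0 - dropout_p)
    return probs @ value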
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py
ADDED
@@ -0,0 +1,243 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default)
div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_17_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])


eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
expand_default = CallFunction(aten.expand.default, view_default, Ignored())
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
_sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5)
div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_17_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None,
  None
])


eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
expand_default = CallFunction(aten.expand.default, view_default, Ignored())
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
_sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
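The `_sfdp_pattern_17_*` expressions differ from pattern 16 in how the mask enters: instead of adding `attn_mask`, they match a boolean `eq.Scalar` test on the mask followed by `where.self`, which overwrites masked positions with a filled scalar before the softmax (the half-precision variants additionally match `convert_element_type` casts around the softmax). A minimal sketch of the matched forward computation follows; the comparison value and fill constant are `Ignored()` in the pattern, so the concrete `0` and `-inf` below are assumptions.

import torch

def sdpa_filled_mask(query, key, value, attn_mask, inv_scale, fill_value=float("-inf")):
    scores = query @ key.transpose(-2, -1) / inv_scale
    # eq.Scalar -> expand -> where.self: masked positions are replaced by a
    # scalar fill rather than shifted additively.
    scores = torch.where(attn_mask == 0, torch.full_like(scores, fill_value), scores)
    scores = scores - scores.amax(dim=-1, keepdim=True)
    probs = scores.exp()
    return (probs / probs.sum(dim=-1, keepdim=True)) @ value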
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py
ADDED
@@ -0,0 +1,452 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2)
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2)
amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default)
div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default)
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_18_training = MultiOutputPattern([view_default_5,
  permute_default_1,
  permute_default_3,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2)
amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
_sfdp_pattern_18_inference = MultiOutputPattern([view_default_5,
  permute_default_1,
  permute_default_3
])


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2)
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2)
amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default)
div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default)
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_18_bs1_training = MultiOutputPattern([view_default_5,
  permute_default_1,
  permute_default_3,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2)
amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
_sfdp_pattern_18_bs1_inference = MultiOutputPattern([view_default_5,
  permute_default_1,
  permute_default_3
])
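The `_sfdp_pattern_18_*` expressions encode the causal-mask variant: the scale enters as division by a scalar `full.default` tensor rather than a `KeywordArg('inv_scale')`, a boolean `causal_mask` selects between the scaled scores and a scalar fill via `where.self`, and the key/value `permute` nodes are returned as extra pattern outputs. A minimal sketch of the matched forward computation; the fill constant is `Ignored()` in the pattern, so `-inf` below is an assumption.

import torch

def sdpa_causal_mask(query, key, value, causal_mask, scale):
    scores = query @ key.transpose(-2, -1) / scale
    # where.self(causal_mask, scores, full_default): keep the score where the
    # boolean mask is True, otherwise substitute a scalar fill.
    scores = torch.where(causal_mask, scores, torch.full_like(scores, float("-inf")))
    scores = scores - scores.amax(dim=-1, keepdim=True)
    probs = scores.exp()
    return (probs / probs.sum(dim=-1, keepdim=True)) @ value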
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2)
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
|
| 286 |
+
where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default)
|
| 287 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default)
|
| 288 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 289 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 290 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
|
| 291 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 292 |
+
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
|
| 293 |
+
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 294 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
|
| 295 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 296 |
+
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 297 |
+
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
| 298 |
+
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 299 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
|
| 300 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 301 |
+
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
|
| 302 |
+
_sfdp_pattern_18_half_training = MultiOutputPattern([view_default_5,
|
| 303 |
+
permute_default_1,
|
| 304 |
+
permute_default_3,
|
| 305 |
+
permute_default_6,
|
| 306 |
+
permute_default_9,
|
| 307 |
+
permute_default_11,
|
| 308 |
+
None,
|
| 309 |
+
None
|
| 310 |
+
])
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 314 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 315 |
+
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
|
| 316 |
+
view_default = CallFunction(aten.view.default, clone_default, Ignored())
|
| 317 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
|
| 318 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 319 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 320 |
+
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
|
| 321 |
+
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
|
| 322 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 323 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 324 |
+
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 325 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
|
| 326 |
+
full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 327 |
+
where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
|
| 328 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
|
| 329 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 330 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 331 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 332 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 333 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 334 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 335 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 336 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 337 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
| 338 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 339 |
+
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
|
| 340 |
+
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
|
| 341 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 342 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 343 |
+
_sfdp_pattern_18_half_inference = MultiOutputPattern([view_default_5,
|
| 344 |
+
permute_default_1,
|
| 345 |
+
permute_default_3
|
| 346 |
+
])
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 350 |
+
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
|
| 351 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 352 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 353 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
|
| 354 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
|
| 355 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 356 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 357 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
|
| 358 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 359 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 360 |
+
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2)
|
| 361 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
|
| 362 |
+
full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 363 |
+
where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
|
| 364 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
|
| 365 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 366 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 367 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 368 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 369 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 370 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
|
| 371 |
+
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
|
| 372 |
+
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
|
| 373 |
+
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
|
| 374 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
|
| 375 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
| 376 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 377 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 378 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 379 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 380 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
| 381 |
+
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
| 382 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 383 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 384 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
|
| 385 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 386 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 387 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
| 388 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
| 389 |
+
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
|
| 390 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
|
| 391 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
| 392 |
+
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
| 393 |
+
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
|
| 394 |
+
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
|
| 395 |
+
where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default)
|
| 396 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default)
|
| 397 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 398 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 399 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
|
| 400 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 401 |
+
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
|
| 402 |
+
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 403 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
|
| 404 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 405 |
+
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 406 |
+
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
|
| 407 |
+
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 408 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
|
| 409 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 410 |
+
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
|
| 411 |
+
_sfdp_pattern_18_half_bs1_training = MultiOutputPattern([view_default_5,
|
| 412 |
+
permute_default_1,
|
| 413 |
+
permute_default_3,
|
| 414 |
+
permute_default_6,
|
| 415 |
+
permute_default_9,
|
| 416 |
+
permute_default_11,
|
| 417 |
+
None,
|
| 418 |
+
None
|
| 419 |
+
])
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
|
| 423 |
+
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 424 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
| 425 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
|
| 426 |
+
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
|
| 427 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
|
| 428 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
| 429 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 430 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 431 |
+
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 432 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
|
| 433 |
+
full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
|
| 434 |
+
where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
|
| 435 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
|
| 436 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 437 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 438 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 439 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 440 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 441 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 442 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 443 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 444 |
+
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
|
| 445 |
+
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
|
| 446 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 447 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 448 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 449 |
+
_sfdp_pattern_18_half_bs1_inference = MultiOutputPattern([view_default_5,
|
| 450 |
+
permute_default_1,
|
| 451 |
+
permute_default_3
|
| 452 |
+
])
|
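
For orientation, and not part of the generated file above: the `_sfdp_pattern_18*` node graphs all serialize the same masked, scaled softmax attention that Inductor's pattern matcher rewrites into a fused scaled-dot-product-attention call; the `permute`/`expand`/`view`/`clone` nodes are layout bookkeeping around it, and the `Ignored()` slots stand for shapes, dtypes, and constants the matcher does not pin down. A minimal eager-mode sketch of the computation the forward half encodes, with hypothetical names:

import torch
import torch.nn.functional as F

def sfdp_pattern_18_reference(query, key, value, causal_mask, scale, dropout_p=0.0):
    # bmm + div.Tensor: scaled attention scores Q @ K^T / scale.
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    # where.self + full.default: keep scores where causal_mask holds,
    # otherwise substitute a large negative fill value.
    masked = torch.where(causal_mask, scores,
                         scores.new_full((), torch.finfo(scores.dtype).min))
    # amax/sub/exp/sum/div: the decomposed, numerically stable softmax.
    attn = F.softmax(masked, dim=-1)
    # rand/gt.Scalar/mul: dropout on the attention probabilities.
    attn = F.dropout(attn, p=dropout_p)
    # final bmm + view: weighted sum over the value tensor.
    return torch.matmul(attn, value)

The `_half` variants additionally wrap the softmax in `prims.convert_element_type` casts, which is how an fp16/bf16 graph that upcasts for the softmax appears after decomposition.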
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py
ADDED
@@ -0,0 +1,208 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2)
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default)
div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default)
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_19_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None,
  None,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_19_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2)
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_3, scalar_tensor_default)
div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default)
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_19_half_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None,
  None,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_19_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
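
Again for orientation only: the `_sfdp_pattern_19*` graphs above differ from pattern 18 mainly by one extra `aten.add.Tensor` node, an additive `attn_mask` applied after the causal `where()`. A hedged eager-mode sketch of just that variant, with hypothetical names:

import torch
import torch.nn.functional as F

def sfdp_pattern_19_reference(query, key, value, causal_mask, attn_mask,
                              scale, dropout_p=0.0):
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    scores = torch.where(causal_mask, scores,
                         scores.new_full((), torch.finfo(scores.dtype).min))
    # add.Tensor: the additive attention mask unique to pattern 19.
    scores = scores + attn_mask
    attn = F.dropout(F.softmax(scores, dim=-1), p=dropout_p)
    return torch.matmul(attn, value)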
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py
ADDED
@@ -0,0 +1,173 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1)
mul_Tensor_2 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor'))
view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_2_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_2_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, KeywordArg('scale_factor'))
view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_2_half_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_2_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
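
For orientation once more: `_sfdp_pattern_2*` is the simplest family in this batch, with no mask and no dropout, and with the logits scaled by multiplying with a `scale_factor` keyword argument instead of dividing by a scale. A minimal sketch with hypothetical names:

import torch
import torch.nn.functional as F

def sfdp_pattern_2_reference(query, key, value, scale_factor):
    # mul.Tensor: logits scaled by multiplication rather than division.
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale_factor
    return torch.matmul(F.softmax(scores, dim=-1), value)

(`_sfdp_pattern_3`, which follows, has the same shape but divides by an `inv_scale_factor` and reintroduces dropout.)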
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py
ADDED
@@ -0,0 +1,189 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_3_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_3_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 131 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
|
| 132 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 133 |
+
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
|
| 134 |
+
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
|
| 135 |
+
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
|
| 136 |
+
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
|
| 137 |
+
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
|
| 138 |
+
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
|
| 139 |
+
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
|
| 140 |
+
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
|
| 141 |
+
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
|
| 142 |
+
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
|
| 143 |
+
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
|
| 144 |
+
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
|
| 145 |
+
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
|
| 146 |
+
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
|
| 147 |
+
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
|
| 148 |
+
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor'))
|
| 149 |
+
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
|
| 150 |
+
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
|
| 151 |
+
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
|
| 152 |
+
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
|
| 153 |
+
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
|
| 154 |
+
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
|
| 155 |
+
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
|
| 156 |
+
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
|
| 157 |
+
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
|
| 158 |
+
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
|
| 159 |
+
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
|
| 160 |
+
_sfdp_pattern_3_half_training = MultiOutputPattern([view_default_5,
|
| 161 |
+
view_default_9,
|
| 162 |
+
permute_default_4,
|
| 163 |
+
view_default_11,
|
| 164 |
+
None,
|
| 165 |
+
None
|
| 166 |
+
])
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
|
| 170 |
+
view_default = CallFunction(aten.view.default, expand_default, Ignored())
|
| 171 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
|
| 172 |
+
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
|
| 173 |
+
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
|
| 174 |
+
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
|
| 175 |
+
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
|
| 176 |
+
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
|
| 177 |
+
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
|
| 178 |
+
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
|
| 179 |
+
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
|
| 180 |
+
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
|
| 181 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
|
| 182 |
+
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
|
| 183 |
+
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
|
| 184 |
+
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
|
| 185 |
+
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
|
| 186 |
+
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
|
| 187 |
+
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
|
| 188 |
+
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
|
| 189 |
+
_sfdp_pattern_3_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
|
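For orientation (an editorial sketch, not part of the generated file): the _sfdp_pattern_3 graphs above serialize the aten-level trace of scaled-dot-product attention where the scores are divided by an inv_scale_factor; the file that follows, _sfdp_pattern_4, has the same shape except the scores are multiplied by a scale_factor instead. A minimal eager-mode equivalent, with a hypothetical helper name and signature:

# Editorial sketch only -- hypothetical helper, not part of the generated file.
import torch
import torch.nn.functional as F

def sdpa_pattern_3_reference(query, key, value, inv_scale_factor, dropout_p=0.0):
    # bmm/view/expand nodes above: batched Q @ K^T, scaled by division
    scores = torch.matmul(query, key.transpose(-2, -1)) / inv_scale_factor
    # amax/sub/exp/sum/div nodes: numerically stable softmax
    attn = F.softmax(scores, dim=-1)
    # rand/gt_Scalar/mul nodes appear only in the *_training variants
    attn = F.dropout(attn, p=dropout_p, training=dropout_p > 0)
    return torch.matmul(attn, value)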
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py
ADDED
@@ -0,0 +1,189 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor)
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3)
mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor_4, div_Tensor, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5)
mul_Tensor_6 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor'))
view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_4_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_4_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_4, Ignored())
mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5)
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor'))
view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_4_half_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_4_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
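A note on the *_half_* variants above (editorial, and an inference from the node sequence rather than a statement in the file): the extra prims.convert_element_type nodes around the softmax are what an up-cast to float32 and a down-cast back to the reduced-precision dtype look like after decomposition. A hedged eager-mode sketch with a hypothetical helper name:

# Editorial sketch only -- hypothetical helper; the fp32 round-trip is an
# inference from the convert_element_type placement, not from the file itself.
import torch
import torch.nn.functional as F

def sdpa_pattern_4_half_reference(query, key, value, scale_factor):
    # query/key/value assumed to be torch.float16 or torch.bfloat16
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale_factor
    attn = F.softmax(scores.float(), dim=-1)  # up-cast: convert_element_type
    attn = attn.to(query.dtype)               # down-cast: convert_element_type_1
    return torch.matmul(attn, value)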
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py
ADDED
@@ -0,0 +1,177 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_5_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_5_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored())
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_5_half_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_5_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
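Pattern 5 differs from patterns 3 and 4 in two ways visible above: an additive attn_mask is applied before the softmax, and the score divisor is Ignored(), so any scaling constant matches. _sfdp_pattern_6, in the next file, is the same graph with dropout added. A minimal eager-mode sketch (hypothetical helper name and signature):

# Editorial sketch only -- hypothetical helper, not part of the generated file.
import torch
import torch.nn.functional as F

def sdpa_pattern_5_reference(query, key, value, attn_mask, scale):
    # div_Tensor node: the divisor is Ignored() in the pattern, so any scale matches
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    scores = scores + attn_mask  # add_Tensor node with KeywordArg('attn_mask')
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, value)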
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py
ADDED
@@ -0,0 +1,193 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_6_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_6_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored())
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_6_half_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_6_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py
ADDED
@@ -0,0 +1,220 @@
| 1 |
+
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    CallFunctionVarArgs,
    CallMethod,
    CallMethodVarArgs,
    CallModule,
    CallModuleVarArgs,
    ExclusiveKeywordArg,
    Ignored,
    KeywordArg,
    ListOf,
    MultiOutputPattern,
    PatternExpr,
    RepeatedExpr,
    _TargetArgsExpr,
    _TargetExpr,
    _TargetExprVarArgs,
)
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_7_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored())
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_7_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_7_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py
ADDED
@@ -0,0 +1,204 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    CallFunctionVarArgs,
    CallMethod,
    CallMethodVarArgs,
    CallModule,
    CallModuleVarArgs,
    ExclusiveKeywordArg,
    Ignored,
    KeywordArg,
    ListOf,
    MultiOutputPattern,
    PatternExpr,
    RepeatedExpr,
    _TargetArgsExpr,
    _TargetExpr,
    _TargetExprVarArgs,
)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_8_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, Ignored())
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_8_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_8_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
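The neg/fma tail of each *_training pattern encodes the softmax backward rule dx = p*dy - p*sum(p*dy, dim=-1), written as fma(neg(p), sum(p*dy), p*dy). A standalone autograd check, a sketch assuming softmax over the last dimension:

import torch

x = torch.randn(4, 7, requires_grad=True)
p = torch.softmax(x, dim=-1)
dy = torch.randn_like(p)
p.backward(dy)

mul = dy * p                                        # mul_Tensor
manual = mul - p * mul.sum(dim=-1, keepdim=True)    # fma(neg_default, sum_dim_IntList_1, mul_Tensor)
assert torch.allclose(manual, x.grad, atol=1e-6)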
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py
ADDED
@@ -0,0 +1,220 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    CallFunctionVarArgs,
    CallMethod,
    CallMethodVarArgs,
    CallModule,
    CallModuleVarArgs,
    ExclusiveKeywordArg,
    Ignored,
    KeywordArg,
    ListOf,
    MultiOutputPattern,
    PatternExpr,
    RepeatedExpr,
    _TargetArgsExpr,
    _TargetExpr,
    _TargetExprVarArgs,
)
rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_9_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)


rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
neg_default = CallFunction(aten.neg.default, div_Tensor_1)
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
view_default_8 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored(), _users=2)
permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
_sfdp_pattern_9_half_training = MultiOutputPattern([view_default_5,
  permute_default_6,
  permute_default_9,
  permute_default_11,
  None
])


permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
view_default = CallFunction(aten.view.default, clone_default, Ignored())
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py
ADDED
@@ -0,0 +1,52 @@
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
# noqa: F401, E501
|
| 4 |
+
# This is an auto-generated file. Please do not modify it by hand.
|
| 5 |
+
# To re-generate, run:
|
| 6 |
+
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch._inductor
|
| 10 |
+
|
| 11 |
+
aten = torch.ops.aten
|
| 12 |
+
prims = torch.ops.prims
|
| 13 |
+
|
| 14 |
+
from torch._inductor.pattern_matcher import (
|
| 15 |
+
Arg,
|
| 16 |
+
CallFunction,
|
| 17 |
+
CallFunctionVarArgs,
|
| 18 |
+
CallMethod,
|
| 19 |
+
CallMethodVarArgs,
|
| 20 |
+
CallModule,
|
| 21 |
+
CallModuleVarArgs,
|
| 22 |
+
ExclusiveKeywordArg,
|
| 23 |
+
Ignored,
|
| 24 |
+
KeywordArg,
|
| 25 |
+
ListOf,
|
| 26 |
+
MultiOutputPattern,
|
| 27 |
+
PatternExpr,
|
| 28 |
+
RepeatedExpr,
|
| 29 |
+
_TargetArgsExpr,
|
| 30 |
+
_TargetExpr,
|
| 31 |
+
_TargetExprVarArgs,
|
| 32 |
+
)
|
| 33 |
+
addmm_default = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha'))
|
| 34 |
+
mul_Scalar = CallFunction(aten.mul.Scalar, KeywordArg('tangents_1'), KeywordArg('beta'))
|
| 35 |
+
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, mul_Scalar, Ignored(), True)
|
| 36 |
+
view_default = CallFunction(aten.view.default, sum_dim_IntList, Ignored())
|
| 37 |
+
permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored())
|
| 38 |
+
mm_default = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default)
|
| 39 |
+
mul_Scalar_1 = CallFunction(aten.mul.Scalar, mm_default, KeywordArg('alpha'))
|
| 40 |
+
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored())
|
| 41 |
+
mm_default_1 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1'))
|
| 42 |
+
mul_Scalar_2 = CallFunction(aten.mul.Scalar, mm_default_1, KeywordArg('alpha'))
|
| 43 |
+
addmm_pattern_training = MultiOutputPattern([addmm_default,
|
| 44 |
+
view_default,
|
| 45 |
+
mul_Scalar_1,
|
| 46 |
+
mul_Scalar_2,
|
| 47 |
+
None,
|
| 48 |
+
None
|
| 49 |
+
])
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
addmm_pattern_inference = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha'), _users=0)
|
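These generated expressions are declarative matchers over ATen FX graphs. As a rough sketch of how one could be exercised directly, assuming PatternExpr.match(node) from torch._inductor.pattern_matcher (which these files import) behaves as its name suggests; this is not something the generated file itself does:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(inp, mat1, mat2):
    return torch.ops.aten.addmm.default(inp, mat1, mat2, beta=2.0, alpha=3.0)

gm = make_fx(f)(torch.randn(4), torch.randn(2, 3), torch.randn(3, 4))
node = next(n for n in gm.graph.nodes if n.target == torch.ops.aten.addmm.default)
# On a successful match this returns a Match whose kwargs bind 'input',
# 'mat1', 'mat2', 'beta' and 'alpha' from the KeywordArg slots above.
print(addmm_pattern_inference.match(node))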
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py
ADDED
@@ -0,0 +1,44 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
bmm_default = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2'))
permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, permute_default_1, KeywordArg('tangents_1'))
bmm_pattern_training = MultiOutputPattern([bmm_default,
  bmm_default_1,
  bmm_default_2
])


bmm_pattern_inference = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0)
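The training pattern pairs the forward bmm with its two backward matmuls, d(mat1) = tangents @ mat2^T and d(mat2) = mat1^T @ tangents, which is why it is a MultiOutputPattern with three outputs. That identity is easy to confirm against autograd:

import torch

mat1 = torch.randn(4, 8, 16, requires_grad=True)
mat2 = torch.randn(4, 16, 32, requires_grad=True)
tangents = torch.randn(4, 8, 32)

torch.bmm(mat1, mat2).backward(tangents)
torch.testing.assert_close(mat1.grad, torch.bmm(tangents, mat2.transpose(1, 2)))
torch.testing.assert_close(mat2.grad, torch.bmm(mat1.transpose(1, 2), tangents))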
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py
ADDED
@@ -0,0 +1,44 @@
# mypy: ignore-errors

# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python torchgen/fuse/gen_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
   Arg,
   CallFunction,
   CallFunctionVarArgs,
   CallMethod,
   CallMethodVarArgs,
   CallModule,
   CallModuleVarArgs,
   ExclusiveKeywordArg,
   Ignored,
   KeywordArg,
   ListOf,
   MultiOutputPattern,
   PatternExpr,
   RepeatedExpr,
   _TargetArgsExpr,
   _TargetExpr,
   _TargetExprVarArgs,
)
mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored())
mm_default_1 = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default)
permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored())
mm_default_2 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1'))
mm_pattern_training = MultiOutputPattern([mm_default,
  mm_default_1,
  mm_default_2
])


mm_pattern_inference = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/split_cat.py
ADDED
The diff for this file is too large to render.
.venv/Lib/site-packages/torch/_inductor/kernel/__init__.py
ADDED
@@ -0,0 +1 @@
from . import mm, mm_common, mm_plus_mm, unpack_mixed_mm
.venv/Lib/site-packages/torch/_inductor/kernel/bmm.py
ADDED
@@ -0,0 +1,192 @@
# mypy: allow-untyped-defs
import logging

import torch

from .. import ir, lowering as L
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    TritonTemplate,
)
from ..utils import (
    ceildiv as cdiv,
    use_aten_gemm_kernels,
    use_cutlass_template,
    use_triton_template,
)
from ..virtualized import V
from .mm import _is_static_problem
from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options


log = logging.getLogger(__name__)
aten = torch.ops.aten


def bmm_grid(b, m, n, meta):
    return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)


bmm_template = TritonTemplate(
    name="bmm",
    grid=bmm_grid,
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", -2)}}
    N = {{size("B", -1)}}
    K = {{size("A", -1)}}

    stride_aq = {{stride("A", 0)}}
    stride_am = {{stride("A", 1)}}
    stride_ak = {{stride("A", 2)}}

    stride_bq = {{stride("B", 0)}}
    stride_bk = {{stride("B", 1)}}
    stride_bn = {{stride("B", 2)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N

    rk = tl.arange(0, BLOCK_K)

    idx_q = tl.program_id(1)  # batch dimension for BMM
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_q = tl.program_id(1)  # batch dimension for BMM
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}}
""",
)

aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out")
aten_baddbmm = ExternKernelChoice(torch.baddbmm, "at::baddbmm_out")


@L.register_lowering(aten.bmm)
def tuned_bmm(mat1, mat2, *, layout=None):
    if all(x.get_device().type == "cpu" for x in [mat1, mat2]):
        # decompose to small ops when memory bound
        if mat1.get_size()[1] == 1 or mat2.get_size()[2] == 1:
            mat1 = L.unsqueeze(mat1, -1)
            mat2 = L.unsqueeze(mat2, 1)
            return L.sum_(L.mul(mat1, mat2), axis=2)

        def is_valid_to_require_contiguous(t):
            if not ir.is_storage_and_layout(t):
                return True
            _, layout = ir.as_storage_and_layout(t, freeze=False)
            return isinstance(layout, ir.FlexibleLayout)

        def is_preferred_layout_as_bmm_input(sizes, strides):
            # contiguous on one of the last two dims
            return (
                strides[-1] == 1 and (sizes[-2] == 1 or strides[-2] >= sizes[-1])
            ) or (strides[-2] == 1 and (sizes[-1] == 1 or strides[-1] >= sizes[-2]))

        # Make the input of bmm contiguous
        # if it is not contiguous on either of the last two dims,
        # because bmm cpu implementation would do contiguous() if not.
        # This is to avoid additional copies in bmm.
        def may_require_contiguous(t, meta_t):
            sizes = meta_t.meta["val"].size()
            strides = meta_t.meta["val"].stride()
            if not is_preferred_layout_as_bmm_input(sizes, strides):
                t = ir.ExternKernel.require_contiguous(t)
            return t

        if is_valid_to_require_contiguous(mat1):
            meta_mat1 = V.graph.current_node.args[0]
            mat1 = may_require_contiguous(mat1, meta_mat1)
        if is_valid_to_require_contiguous(mat2):
            meta_mat2 = V.graph.current_node.args[1]
            mat2 = may_require_contiguous(mat2, meta_mat2)

    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)

    # options to tune from
    choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
    if use_triton_template(layout):
        for config in mm_configs(m, n, k):
            bmm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate

        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

    if len(choices) == 0:
        log.warning("No choices for GEMM, using ATen backend as fallback")
        choices.append(aten_bmm.bind((mat1, mat2), layout))

    return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)


# Don't register this since it is slower than decomposing it
# @L.register_lowering(aten.baddbmm)
def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)

    # options to tune from
    choices = (
        [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
        if use_aten_gemm_kernels()
        else []
    )
    if use_triton_template(layout):
        for config in mm_configs(m, n, k):
            bmm_template.maybe_append_choice(
                choices,
                input_nodes=(inp, mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
                prefix_args=1,
                epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
            )

    return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout)
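Two small standalone checks of what this file does, sketched in plain Python and PyTorch (the shapes are illustrative, not from the diff). First, bmm_grid launches one program per BLOCK_M x BLOCK_N output tile per batch element; second, the CPU memory-bound branch in tuned_bmm rewrites bmm as a broadcasted multiply plus a reduction over the k axis, which is exact:

import torch

def cdiv(a, b):
    return -(a // -b)  # ceiling division, mirroring inductor's ceildiv

# grid for b=8 batches of a 1000x512 output tiled 64x64
meta = {"BLOCK_M": 64, "BLOCK_N": 64}
print((cdiv(1000, meta["BLOCK_M"]) * cdiv(512, meta["BLOCK_N"]), 8, 1))  # (128, 8, 1)

# the unsqueeze/mul/sum decomposition taken when m == 1 or n == 1
mat1 = torch.randn(4, 1, 64)   # m == 1: the memory-bound case
mat2 = torch.randn(4, 64, 32)
decomposed = (mat1.unsqueeze(-1) * mat2.unsqueeze(1)).sum(dim=2)
torch.testing.assert_close(decomposed, torch.bmm(mat1, mat2))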
.venv/Lib/site-packages/torch/_inductor/kernel/conv.py
ADDED
@@ -0,0 +1,679 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from __future__ import annotations

import functools
import logging
from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDict

import torch

from .. import config, ir
from ..lowering import (
    add_layout_constraint,
    constrain_to_fx_strides,
    lowerings as L,
    register_lowering,
)
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    TritonTemplate,
)
from ..utils import (
    ceildiv,
    is_ones,
    is_zeros,
    pad_listlike,
    sympy_product,
    use_triton_template,
)
from ..virtualized import V
from .mm_common import filtered_configs


if TYPE_CHECKING:
    from ..ir import TensorBox

log = logging.getLogger(__name__)


aten = torch.ops.aten


def conv2d_grid(n, c, h, w, meta):
    return (
        ceildiv(n * h * w, meta["BLOCK_M"]),
        ceildiv(c, meta["BLOCK_N"]),
        meta["GROUPS"],
    )


def conv3d_grid(n, c, d, h, w, meta):
    return (
        ceildiv(n * d * h * w, meta["BLOCK_M"]),
        ceildiv(c, meta["BLOCK_N"]),
        meta["GROUPS"],
    )


# List of dictionaries to store the kernel configs. Configs that evaluate to true
# will be utilised on the target platform
kernel_configs = [
    # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
    {"config": (64, 256, 16, 2, 4), "cond": True},
    {"config": (256, 64, 16, 2, 4), "cond": True},
    {"config": (1024, 16, 16, 1, 8), "cond": True},
    {"config": (128, 128, 32, 2, 8), "cond": True},
    {"config": (64, 64, 32, 2, 4), "cond": True},
    {"config": (64, 256, 32, 2, 8), "cond": True},
    {"config": (256, 64, 32, 2, 8), "cond": True},
]

# Create filtered list of configs based on conv
platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in kernel_configs
    if config["cond"]
)

# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip:
    platform_configs = tuple(
        (config[0], config[1], config[2], 1, config[4]) for config in platform_configs
    )

conv_configs = functools.partial(
    filtered_configs,
    configs=platform_configs,
)

LOOP_BODY_2D = """
        idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
        idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
        idx_x_c = tl.arange(0, BLOCK_K) + k

        x_ptrs = x_base + (
            (idx_x_h * stride_xh)[:, None]
            + (idx_x_w * stride_xw)[:, None]
            + (idx_x_c * stride_xc)[None, :]
        )
        mask_x = (
            (idx_n < BATCH)[:, None]
            & (idx_x_h >= 0)[:, None]
            & (idx_x_h < IN_H)[:, None]
            & (idx_x_w >= 0)[:, None]
            & (idx_x_w < IN_W)[:, None]
            & (idx_x_c < GROUP_IN_C)[None, :]
        )
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)

        w_ptrs = w_base + (
            (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww)
        )
        mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
        acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
"""

"""
This is a relatively simple conv implementation that can likely be
improved. Many alternate conv versions can be found here:
https://github.com/pytorch/torchdynamo/pull/971
"""
conv2d_template = TritonTemplate(
    name="convolution2d",
    grid=conv2d_grid,
    source=r"""
{{def_kernel("X", "W")}}
    # Tensor dimensions
    BATCH = {{size("X", 0)}}
    IN_C = {{size("X", 1)}}
    IN_H = {{size("X", 2)}}
    IN_W = {{size("X", 3)}}
    OUT_C = {{size(None, 1)}}
    OUT_H = {{size(None, 2)}}
    OUT_W = {{size(None, 3)}}

    # Strides:
    stride_xn = {{stride("X", 0)}}
    stride_xc = {{stride("X", 1)}}
    stride_xh = {{stride("X", 2)}}
    stride_xw = {{stride("X", 3)}}
    stride_wc_out = {{stride("W", 0)}}
    stride_wc_in = {{stride("W", 1)}}
    stride_wh = {{stride("W", 2)}}
    stride_ww = {{stride("W", 3)}}

    nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    idx_y_w = nhw % OUT_W
    nh = nhw // OUT_W
    idx_y_h = nh % OUT_H
    idx_n = nh // OUT_H
    idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

{% if GROUPS == 1 %}
    group = 0
    GROUP_IN_C = IN_C
    GROUP_OUT_C = OUT_C
{% else %}
    group = tl.program_id(2)
    GROUP_IN_C = IN_C // GROUPS
    GROUP_OUT_C = OUT_C // GROUPS
{% endif %}

    x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
    w_base = (
        W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
    )

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

{% if UNROLL %}
{% for i in range(KERNEL_H) %}
{% for j in range(KERNEL_W) %}
    i = {{i}}
    j = {{j}}
    for k in range(0, GROUP_IN_C, BLOCK_K):
"""
    + LOOP_BODY_2D
    + """
{% endfor %}
{% endfor %}
{% else %}
    # Could be simplified, but slightly slower:
    # for i in range(KERNEL_H):
    #     for j in range(KERNEL_W):
    #         for k in range(0, GROUP_IN_C, BLOCK_K):
    BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
    for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
        k = (ijk % BLOCK_K_COUNT) * BLOCK_K
        ij = ijk // BLOCK_K_COUNT
        i = ij // KERNEL_W
        j = ij % KERNEL_W
"""
    + LOOP_BODY_2D
    + """
{% endif %}

    mask = (
        (idx_n < BATCH)[:, None]
        & (idx_y_h < OUT_H)[:, None]
        & (idx_y_w < OUT_W)[:, None]
        & (idx_y_c < GROUP_OUT_C)[None, :]
    )
    idx_n = idx_n[:, None]
    idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
    idx_h = idx_y_h[:, None]
    idx_w = idx_y_w[:, None]

    # inductor generates a suffix
    {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}}
""",
)

LOOP_BODY_3D = """
        idx_x_d = d - PADDING_D + idx_y_d * STRIDE_D
        idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
        idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
        idx_x_c = tl.arange(0, BLOCK_K) + k

        x_ptrs = x_base + (
            (idx_x_d * stride_xd)[:, None]
            + (idx_x_h * stride_xh)[:, None]
            + (idx_x_w * stride_xw)[:, None]
            + (idx_x_c * stride_xc)[None, :]
        )
        mask_x = (
            (idx_n < BATCH)[:, None]
            & (idx_x_d >= 0)[:, None]
            & (idx_x_d < IN_D)[:, None]
            & (idx_x_h >= 0)[:, None]
            & (idx_x_h < IN_H)[:, None]
            & (idx_x_w >= 0)[:, None]
            & (idx_x_w < IN_W)[:, None]
            & (idx_x_c < GROUP_IN_C)[None, :]
        )
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)

        w_ptrs = w_base + (
            (idx_x_c * stride_wc_in)[:, None] +
            (d * stride_wd) + (i * stride_wh) + (j * stride_ww)
        )
        mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
        acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
"""

conv3d_template = TritonTemplate(
    name="convolution3d",
    grid=conv3d_grid,
    source=r"""
{{def_kernel("X", "W")}}
    # Tensor dimensions
    BATCH = {{size("X", 0)}}
    IN_C = {{size("X", 1)}}
    IN_D = {{size("X", 2)}}
    IN_H = {{size("X", 3)}}
    IN_W = {{size("X", 4)}}
    OUT_C = {{size(None, 1)}}
    OUT_D = {{size(None, 2)}}
    OUT_H = {{size(None, 3)}}
    OUT_W = {{size(None, 4)}}

    # Strides:
    stride_xn = {{stride("X", 0)}}
    stride_xc = {{stride("X", 1)}}
    stride_xd = {{stride("X", 2)}}
    stride_xh = {{stride("X", 3)}}
    stride_xw = {{stride("X", 4)}}
    stride_wc_out = {{stride("W", 0)}}
    stride_wc_in = {{stride("W", 1)}}
    stride_wd = {{stride("W", 2)}}
    stride_wh = {{stride("W", 3)}}
    stride_ww = {{stride("W", 4)}}

    ndhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    idx_y_w = ndhw % OUT_W
    ndh = ndhw // OUT_W
    idx_y_h = ndh % OUT_H
    nd = ndh // OUT_H
    idx_y_d = nd % OUT_D
    idx_n = nd // OUT_D
    idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

{% if GROUPS == 1 %}
    group = 0
    GROUP_IN_C = IN_C
    GROUP_OUT_C = OUT_C
{% else %}
    group = tl.program_id(2)
    GROUP_IN_C = IN_C // GROUPS
    GROUP_OUT_C = OUT_C // GROUPS
{% endif %}

    x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
    w_base = (
        W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
    )

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

{% if UNROLL %}
{% for d in range(KERNEL_D) %}
{% for i in range(KERNEL_H) %}
{% for j in range(KERNEL_W) %}
    d = {{d}}
    i = {{i}}
    j = {{j}}
    for k in range(0, GROUP_IN_C, BLOCK_K):
"""
    + LOOP_BODY_3D
    + """
{% endfor %}
{% endfor %}
{% endfor %}
{% else %}
    # Could be simplified, but slightly slower:
    # for d in range(KERNEL_D):
    #     for i in range(KERNEL_H):
    #         for j in range(KERNEL_W):
    #             for k in range(0, GROUP_IN_C, BLOCK_K):
    BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
    for dijk in range(KERNEL_D * KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
        k = (dijk % BLOCK_K_COUNT) * BLOCK_K
        dij = dijk // BLOCK_K_COUNT
        j = dij % KERNEL_W
        di = dij // KERNEL_W
        i = di % KERNEL_H
        d = di // KERNEL_H
"""
    + LOOP_BODY_3D
    + """
{% endif %}

    mask = (
        (idx_n < BATCH)[:, None]
        & (idx_y_d < OUT_D)[:, None]
        & (idx_y_h < OUT_H)[:, None]
        & (idx_y_w < OUT_W)[:, None]
        & (idx_y_c < GROUP_OUT_C)[None, :]
    )
    idx_n = idx_n[:, None]
    idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
    idx_d = idx_y_d[:, None]
    idx_h = idx_y_h[:, None]
    idx_w = idx_y_w[:, None]

    # inductor generates a suffix
    {{store_output(("idx_n", "idx_c", "idx_d", "idx_h", "idx_w"), "acc", "mask")}}
""",
)

aten_convolution = ExternKernelChoice(
    torch.convolution,
    "at::convolution",
    has_out_variant=False,
    op_overload=aten.convolution.default,
)


def conv1x1_via_mm(x, w, *, out):
    w = torch.squeeze(torch.squeeze(w, -1), -1)
    return torch.matmul(
        x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1)
    )


aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None)


class ConvLayoutParams(TypedDict):
    stride: tuple[int, ...]
    padding: tuple[int, ...]
    dilation: tuple[int, ...]
    transposed: bool
    output_padding: tuple[int, ...]
    groups: int


def conv_layout(
    x: TensorBox,
    weight: TensorBox,
    bias: Optional[TensorBox],
    stride: Sequence[int],
    padding: tuple[int, ...],
    dilation: tuple[int, ...],
    transposed: bool,
    output_padding: tuple[int, ...],
    groups: int,
) -> ir.Layout:
    """Determine output layout for a convolution"""
    with V.graph.fake_mode:
        output = torch.ops.aten.convolution(
            ir.ir_node_to_tensor(x, guard_shape=True),
            ir.ir_node_to_tensor(weight, guard_shape=True),
            ir.ir_node_to_tensor(bias, guard_shape=True),
            V.graph.sizevars.size_hints(stride),  # type: ignore[arg-type]
            V.graph.sizevars.size_hints(padding),  # type: ignore[arg-type]
            V.graph.sizevars.size_hints(dilation),  # type: ignore[arg-type]
            transposed,
            V.graph.sizevars.size_hints(output_padding),  # type: ignore[arg-type]
            groups,
        )
        sizes = ir.convert_shape_to_inductor(output.size())
        stride = ir.convert_shape_to_inductor(output.stride())  # type: ignore[assignment]

    return ir.FixedLayout(
        x.get_device(),
        x.get_dtype(),
        sizes,
        stride,
    )


def channels_last_order(rank):
    order = list(reversed(range(rank)))
    order.insert(1, order.pop(-1))
    return order


def convert_1x1_conv_to_mm(x, weight, bias):
    # special case for 1x1 convolution, which is actually just a matmul
    rank = len(weight.get_size())
    for _ in range(rank - 2):
        weight = L[aten.squeeze](weight, dim=-1)
    weight = L[aten.permute](weight, [1, 0])

    x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank))
    x_permute = list(range(rank))
    x_permute.append(x_permute.pop(1))
    x = L[aten.permute](x, x_permute)
    *sizes, in_chan = x.get_size()
    x = L[aten.reshape](x, [sympy_product(sizes), in_chan])
    if bias is None:
        result = L[aten.mm](x, weight)
    else:
        result = L[aten.addmm](bias, x, weight)
    result = L[aten.reshape](result, [*sizes, -1])
    result_permute = list(range(rank))
    result_permute.insert(1, result_permute.pop(-1))
    return L[aten.permute](result, result_permute)


@register_lowering(aten.convolution)
def convolution(
    x: TensorBox,
    weight: TensorBox,
    bias: TensorBox,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    transposed: bool,
    output_padding: List[int],
    groups: int,
):
    stride = tuple(stride)
    padding = tuple(padding)
    dilation = tuple(dilation)
    output_padding = tuple(output_padding)
    if not isinstance(groups, int):
        groups = V.graph.sizevars.evaluate_static_shape(groups)
    assert isinstance(groups, int)

    # Need use hint for triton template since the template does not
    # work with a dynamic shape.
    #
    # No need to evaluate_static_shape for dilation and output_padding
    # since the template is only used when dilation is 1 and output_padding
    # is 0.
    stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride))
    padding = tuple(V.graph.sizevars.evaluate_static_shapes(padding))

    kwargs: ConvLayoutParams = {
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "transposed": transposed,
        "output_padding": output_padding,
        "groups": groups,
    }

    if len(x.get_size()) == len(weight.get_size()) - 1:
        # add batch dimension to simplify rest of function
        return L[aten.squeeze](
            convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs),
            dim=0,
        )

    out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes(
        weight.get_size()
    )
    ndim = len(kernel_shape)
    stride = pad_listlike(stride, ndim)
    padding = pad_listlike(padding, ndim)
    dilation = pad_listlike(dilation, ndim)
    output_padding = pad_listlike(output_padding, ndim)

    def channels_last_conv():
        if V.graph.layout_opt and ndim == 2:
            return True

        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        return req_stride_order == ir.NHWC_STRIDE_ORDER

    autotuning_gemm = config.max_autotune or config.max_autotune_gemm

    if (
        (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv()))
        and is_ones(kernel_shape)
        and is_ones(stride)
        and is_zeros(padding)
        and is_ones(dilation)
        and not transposed
        and is_zeros(output_padding)
        and groups == 1
        and V.graph.sizevars.statically_known_gt(sympy_product(x.get_size()), 0)
    ):
        return convert_1x1_conv_to_mm(x, weight, bias)

    if bias is not None and ir.get_device_type(x) != "cpu":
        # peel off the bias, cudnn is slower with it
        result = convolution(x, weight, None, **kwargs)
        return L[aten.add](
            result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1])
        )

    x.realize()
    weight.realize()

    # ndim can be 1 for convolution in models such as demucs
    # TODO: check if it's beneficial to convert Conv1d to Conv2d and then
    # apply channels last.
    if V.graph.layout_opt and ndim == 2:
        V.graph.num_channels_last_conv += 1
        x = ir.ExternKernel.require_channels_last(x)
        # TODO maybe we can convert weights to channels last just once before
        # running the model.
        weight = ir.ExternKernel.require_channels_last(weight)
        layout = conv_layout(x, weight, None, **kwargs)
    else:
        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        x = ir.ExternKernel.require_stride_order(x, req_stride_order)
        weight = ir.ExternKernel.require_stride_order(weight, req_stride_order)

    ordered_kwargs_for_cpp_kernel = [
        "stride",
        "padding",
        "dilation",
        "transposed",
        "output_padding",
        "groups",
    ]
    if bias is None:
        args = [x, weight]
        kwargs["bias"] = None  # type: ignore[typeddict-unknown-key]
        ordered_kwargs_for_cpp_kernel.insert(0, "bias")
    else:
        args = [x, weight, bias]
        bias.realize()
        bias.freeze_layout()
        V.graph.sizevars.evaluate_static_shapes(bias.get_size())

    choices = []
    if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
        choices = [
            aten_convolution.bind(
                args,
                layout,
                ordered_kwargs_for_cpp_kernel,
                **kwargs,
            )
        ]

    if (
        torch._inductor.utils._use_conv_autotune_backend("TRITON")
        and use_triton_template(layout)
        # templates only support these:
        and is_ones(dilation)
        and not transposed
        and is_zeros(output_padding)
        # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
        and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1])  # type: ignore[arg-type]
    ):
        if (
            is_ones(kernel_shape)
            and is_ones(stride)
            and is_zeros(padding)
            and groups == 1
        ):
            choices.append(aten_conv1x1_via_mm.bind(args, layout))

        for cfg in conv_configs(
            sympy_product([x.get_size()[0], *x.get_size()[2:]]),
            out_chan,
            in_chan,
        ):
            if ndim == 2:
                conv2d_template.maybe_append_choice(
                    choices,
                    input_nodes=(x, weight),
                    layout=layout,
                    KERNEL_H=kernel_shape[0],
                    KERNEL_W=kernel_shape[1],
                    STRIDE_H=stride[0],
                    STRIDE_W=stride[1],
                    PADDING_H=padding[0],
                    PADDING_W=padding[1],
                    GROUPS=groups,
                    # TODO(jansel): try unroll for bigger kernels once fixed:
                    #               https://github.com/openai/triton/issues/1254
                    UNROLL=is_ones(kernel_shape),
                    ALLOW_TF32=torch.backends.cudnn.allow_tf32,
                    num_stages=cfg.num_stages,
                    num_warps=cfg.num_warps,
                    **cfg.kwargs,
                )
            elif ndim == 3:
                conv3d_template.maybe_append_choice(
                    choices,
                    input_nodes=(x, weight),
                    layout=layout,
                    KERNEL_D=kernel_shape[0],
                    KERNEL_H=kernel_shape[1],
                    KERNEL_W=kernel_shape[2],
                    STRIDE_D=stride[0],
                    STRIDE_H=stride[1],
                    STRIDE_W=stride[2],
                    PADDING_D=padding[0],
                    PADDING_H=padding[1],
                    PADDING_W=padding[2],
                    GROUPS=groups,
                    # TODO(jansel): try unroll for bigger kernels once fixed:
                    #               https://github.com/openai/triton/issues/1254
                    UNROLL=is_ones(kernel_shape),
                    ALLOW_TF32=torch.backends.cudnn.allow_tf32,
                    num_stages=cfg.num_stages,
                    num_warps=cfg.num_warps,
                    **cfg.kwargs,
                )

    return autotune_select_algorithm("convolution", choices, args, layout)


@register_lowering(aten._convolution)
def _convolution(
    x,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    benchmark,
    deterministic,
    cudnn_enabled,
    allow_tf32,
):
    return convolution(
        x, weight, bias, stride, padding, dilation, transposed, output_padding, groups
    )


def constrain_conv_to_fx_strides(fx_node, *args, **kwargs):
    assert fx_node.target == torch.ops.aten.convolution.default
    if V.graph.layout_opt:
        return args, kwargs
    else:
        return constrain_to_fx_strides(fx_node, *args, **kwargs)


add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides)
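Two properties this lowering relies on can be checked in plain PyTorch (a sketch with illustrative shapes, not part of the file): a 1x1, stride-1, unpadded, ungrouped convolution is exactly a matmul over the channel dimension, and channels_last_order(4) reproduces inductor's NHWC stride order [3, 0, 2, 1]:

import torch

x = torch.randn(2, 8, 5, 5)    # NCHW input
w = torch.randn(16, 8, 1, 1)   # 1x1 kernel, groups == 1

conv_out = torch.nn.functional.conv2d(x, w)
# move channels last, contract against the (out, in) weight matrix, move back
mm_out = (x.permute(0, 2, 3, 1) @ w.squeeze(-1).squeeze(-1).t()).permute(0, 3, 1, 2)
torch.testing.assert_close(conv_out, mm_out)

def channels_last_order(rank):
    order = list(reversed(range(rank)))
    order.insert(1, order.pop(-1))
    return order

print(channels_last_order(4))  # [3, 0, 2, 1], matching ir.NHWC_STRIDE_ORDER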
.venv/Lib/site-packages/torch/_inductor/kernel/flex_attention.py
ADDED
@@ -0,0 +1,1843 @@
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
""" Triton Implementation of the flex_attention Kernel"""
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
import math
|
| 6 |
+
from typing import Any, List, Optional, Sequence, Tuple
|
| 7 |
+
|
| 8 |
+
import sympy
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch._inductor.virtualized import V
|
| 12 |
+
from torch.utils._pytree import tree_map
|
| 13 |
+
|
| 14 |
+
from .. import config
|
| 15 |
+
from ..ir import (
|
| 16 |
+
ComputedBuffer,
|
| 17 |
+
ExternKernel,
|
| 18 |
+
FixedLayout,
|
| 19 |
+
FlexibleLayout,
|
| 20 |
+
get_stride_order,
|
| 21 |
+
InputBuffer,
|
| 22 |
+
IRNode,
|
| 23 |
+
StorageBox,
|
| 24 |
+
stride_order2fill_order,
|
| 25 |
+
Subgraph,
|
| 26 |
+
TensorBox,
|
| 27 |
+
)
|
| 28 |
+
from ..lowering import empty, empty_strided, lowerings, register_lowering
|
| 29 |
+
from ..select_algorithm import autotune_select_algorithm, realize_inputs, TritonTemplate
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
log = logging.getLogger(__name__)
|
| 33 |
+
aten = torch.ops.aten
|
| 34 |
+
Expr = sympy.Expr
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def construct_strides(
|
| 38 |
+
sizes: Sequence[int],
|
| 39 |
+
fill_order: Sequence[int],
|
| 40 |
+
) -> Sequence[int]:
|
| 41 |
+
"""From a list of sizes and a fill order, construct the strides of the permuted tensor."""
|
| 42 |
+
# Initialize strides
|
| 43 |
+
assert len(sizes) == len(
|
| 44 |
+
fill_order
|
| 45 |
+
), "Length of sizes must match the length of the fill order"
|
| 46 |
+
strides = [0] * len(sizes)
|
| 47 |
+
|
| 48 |
+
# Start with stride 1 for the innermost dimension
|
| 49 |
+
current_stride = 1
|
| 50 |
+
|
| 51 |
+
# Iterate through the fill order populating strides
|
| 52 |
+
for dim in fill_order:
|
| 53 |
+
strides[dim] = current_stride
|
| 54 |
+
current_stride *= sizes[dim]
|
| 55 |
+
|
| 56 |
+
return strides
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta):
|
| 60 |
+
"""How is this kernel parallelized?
|
| 61 |
+
We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1)
|
| 62 |
+
Each block is responsible for iterating over blocks of keys and values calculating
|
| 63 |
+
the final attention output.
|
| 64 |
+
"""
|
| 65 |
+
import triton
|
| 66 |
+
|
| 67 |
+
return (triton.cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1)
|
| 68 |
+
|
| 69 |
+
|
def create_placeholder(
    name: str, dtype: torch.dtype, device: torch.device
) -> TensorBox:
    """Creates a placeholder input buffer for producing subgraph_output."""
    input_buffer = InputBuffer(name, FixedLayout(device, dtype, [], []))
    return TensorBox.create(input_buffer)


def maybe_realize(args: List[Optional[IRNode]]):
    """Accepts a list of optional IRNodes and returns a list of realized IRNodes"""
    return tree_map(lambda x: realize_inputs(x) if x is not None else None, args)


def get_float32_precision():
    if torch.get_float32_matmul_precision() == "highest" or torch.version.hip:
        return "'ieee'"
    else:
        return "'tf32'"

def build_subgraph_buffer(
    args: List[TensorBox],
    subgraph: Subgraph,
):
    """This function's goal is to take in the required args and produce the subgraph buffer.
    The subgraph buffer is a ComputedBuffer that will be inlined into the triton template

    Args:
        args: The args that are passed into the subgraph. Contains both fixed and lifted inputs.
        subgraph: The Subgraph ir for which to produce the output node
    """
    cnt = 0
    env = {}
    for node in subgraph.graph_module.graph.nodes:
        # There are two classes of placeholder inputs that we need
        # to handle differently. For the first n_scalar_inps inputs
        # we expect that these placeholders were generated by the make_fx call
        # in the flex Attention HOP. So we need to create a new placeholder
        # TensorBox for each of these inputs. For the rest of the inputs we
        # expect that these are lifted inputs that fill up the '*other_buffers'
        # tuple and already have corresponding TensorBoxes passed in as args.
        if node.op == "placeholder":
            env[node] = args[cnt]
            cnt += 1
        elif node.op == "call_function":
            # For call_function we use the default lowerings and pass in the
            # already created TensorBoxes as args

            args, kwargs = tree_map(
                lambda x: env[x] if x in env else x, (node.args, node.kwargs)
            )
            env[node] = lowerings[node.target](*args, **kwargs)
        elif node.op == "output":

            def convert_output_node_to_buffer(output):
                if output is None:
                    return None
                output_node = output
                output_buffer = env[output_node]
                assert isinstance(output_buffer, TensorBox), (
                    "The output node for flex attention's subgraph must be a TensorBox, but got: ",
                    type(output_buffer),
                )
                assert isinstance(output_buffer.data, StorageBox), (
                    "The output node for the flex attention subgraph must be a StorageBox, but got: ",
                    type(output_buffer),
                )
                subgraph_buffer = ComputedBuffer(
                    name=None,
                    layout=FlexibleLayout(
                        device=output_buffer.data.get_device(),
                        dtype=output_buffer.data.get_dtype(),
                        size=output_buffer.data.get_size(),
                    ),
                    data=output_buffer.data.data,  # type: ignore[arg-type]
                )
                return subgraph_buffer

            # node.args[0] is either a single element or a list of elements
            # representing all outputs of the function.
            return tree_map(convert_output_node_to_buffer, node.args[0])

    raise ValueError("FlexAttention was passed a subgraph with no output node!")

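# Illustrative sketch (not part of the original file): what a traced score_mod
# subgraph looks like. The HOP traces score_mod with scalar placeholders
# (score, b, h, m, n); build_subgraph_buffer replays that graph node-by-node
# through inductor lowerings. Assumes make_fx from proxy_tensor; the names
# below are hypothetical.
def _demo_score_mod_graph():
    from torch.fx.experimental.proxy_tensor import make_fx

    def rel_bias(score, b, h, m, n):
        return score + (m - n)

    gm = make_fx(rel_bias)(
        torch.tensor(0.0), torch.tensor(0), torch.tensor(0),
        torch.tensor(0), torch.tensor(0),
    )
    # gm.graph has 5 placeholder nodes followed by call_function nodes and an
    # output node, which is exactly the node.op dispatch handled above.
    return gm
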
# Inner Triton functions shared by flex_attention & split-k decoding kernels.
compute_next_offset_func = r"""
@triton.jit
def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK

    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset
"""

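# Illustrative sketch (not part of the original file): a pure-Python model of
# get_offset_for_next_block above. Within a sparse block we advance by BLOCK;
# on the last sub-iteration we jump from the current sparse block to the next
# one listed in col_indices, compensating for the BLOCKs already walked.
# Names are hypothetical and for illustration only.
def _demo_next_offset(loop_iter, col_indices, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = col_indices[cur_block_idx]
    next_block = col_indices[min(cur_block_idx + 1, len(col_indices) - 1)]
    if (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0:
        # Jump to the next sparse block, minus the sub-blocks already covered.
        return (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    return BLOCK


# With SPARSE_BLOCK=128 and BLOCK=64 (2 sub-iterations per sparse block) and
# sparse KV blocks [0, 3]: iter 0 advances 64 within block 0; iter 1 jumps
# (3 - 0) * 128 - 1 * 64 = 320, landing at the start of sparse block 3.
assert _demo_next_offset(0, [0, 3], 128, 2, 64) == 64
assert _demo_next_offset(1, [0, 3], 128, 2, 64) == 320
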
compute_flex_attention = r"""
{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # M: Number of queries, N: Number of keys/values, D: Model dimension
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    #
    # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    #
    # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
    #
    # (Modifiable) Performance tuning options
    # BLOCK_M: The thread block size across the seqlen dim of Q.
    # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.

    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.
    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
    # is not masked out? If so, we can skip an extra safety check

    tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
    tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)

    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}}
    stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
    stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}

    Z = {{size("Q", 0)}}
    HQ = {{size("Q", 1)}}
    Q_LEN = {{size("Q", 2)}}
    KV_LEN = {{size("K", 2)}}

    MATMUL_PRECISION = Q.dtype.element_ty

    q_start = tl.program_id(0)
    off_z = tl.program_id(1) // HQ
    off_hq = tl.program_id(1) % HQ
    off_hkv = off_hq // GQA_SHARED_HEADS
    off_g = off_hq % GQA_SHARED_HEADS

    q_offset = off_z * stride_qz + off_hq * stride_qh
    k_offset = off_z * stride_kz + off_hkv * stride_kh
    v_offset = off_z * stride_vz + off_hkv * stride_vh

    Q = Q + q_offset
    K = K + k_offset
    V = V + v_offset

    SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
    SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}

    sparse_idx_z = off_z % SPARSE_Z
    sparse_idx_hq = off_hq % SPARSE_HQ

    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)

    SPARSE_Q_BLOCK_CNT: tl.constexpr = tl.cdiv(Q_LEN, SPARSE_Q_BLOCK_SIZE)
    SPARSE_KV_BLOCK_CNT: tl.constexpr = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)

    offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)

    # KV_IDX and KV_NUM_BLKS are always contiguous.
    sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
    sparse_kv_num_blks_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT + q_start // SPARSE_Q_MULTIPLE
    sparse_kv_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (q_start // SPARSE_Q_MULTIPLE) * SPARSE_KV_BLOCK_CNT  # noqa: B950

    Q_block_ptr = tl.make_block_ptr(
        base=Q,
        shape=(Q_LEN, QK_HEAD_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(q_start * BLOCK_M, 0),
        block_shape=(BLOCK_M, QK_HEAD_DIM),
        order=(1, 0)
    )

    # load q: it stays in SRAM throughout the inner loop.
    if IS_DIVISIBLE:
        q = tl.load(Q_block_ptr)
    else:
        # boundary check is not free, so we only do it when necessary.
        q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")

    # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We don't know anything "special" about these blocks, so we need to apply
    # both score_mod and mask_mod to them
    kv_indices = KV_IDX + sparse_kv_idx_offset
    kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
    block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

    K_block_ptr = tl.make_block_ptr(
        base=K,
        shape=(QK_HEAD_DIM, KV_LEN),
        strides=(stride_kk, stride_kn),
        offsets=(0, kv_start),
        block_shape=(QK_HEAD_DIM, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V,
        shape=(KV_LEN, V_HEAD_DIM),
        strides=(stride_vn, stride_vk),
        offsets=(kv_start, 0),
        block_shape=(BLOCK_N, V_HEAD_DIM),
        order=(1, 0)
    )
    offs_n = kv_start + tl.arange(0, BLOCK_N)

    acc, l_i, m_i = forward_inner(
        {{gen_argdefs()}},
        q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
        acc, l_i, m_i,
        off_z, off_hq, offs_m[:, None], offs_n[None, :],
        kv_indices, kv_num_blocks,
        0, block_n_end,
        MATMUL_PRECISION,
        IS_FULL_BLOCKS=False,
    )

    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We know these blocks are guaranteed to be "full", so we don't need to
    # apply mask_mod to them - only score_mod
    if HAS_FULL_BLOCKS:
        # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
        kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
        block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

        K_block_ptr = tl.make_block_ptr(
            base=K,
            shape=(QK_HEAD_DIM, KV_LEN),
            strides=(stride_kk, stride_kn),
            offsets=(0, kv_start),
            block_shape=(QK_HEAD_DIM, BLOCK_N),
            order=(0, 1)
        )
        V_block_ptr = tl.make_block_ptr(
            base=V,
            shape=(KV_LEN, V_HEAD_DIM),
            strides=(stride_vn, stride_vk),
            offsets=(kv_start, 0),
            block_shape=(BLOCK_N, V_HEAD_DIM),
            order=(1, 0)
        )
        offs_n = kv_start + tl.arange(0, BLOCK_N)

        acc, l_i, m_i = forward_inner(
            {{gen_argdefs()}},
            q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
            acc, l_i, m_i,
            off_z, off_hq, offs_m[:, None], offs_n[None, :],
            kv_indices, kv_num_blocks,
            0, block_n_end,
            MATMUL_PRECISION,
            IS_FULL_BLOCKS=True,
        )


    # [Note] Handle fully masked out rows:
    # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
    # We set Li to 1.0 which will result in lse/out = 0.0 after the log(li) + mi(0.0) step
    l_i = tl.where(l_i == 0.0, 1, l_i)

    acc = acc / l_i[:, None]
    idx_z = tl.program_id(1) // HQ
    idx_hq = tl.program_id(1) % HQ
    idx_m = offs_m[:, None]
    idx_d = tl.arange(0, V_HEAD_DIM)[None, :]

    mask = idx_m < Q_LEN
    # TODO generalize and add proper mask support
    {{store_output(("idx_z", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}

    # TODO: don't write this if we don't require grad
    if OUTPUT_LOGSUMEXP:
        off_hz = tl.program_id(1)
        l_ptrs = LSE + off_hz * Q_LEN + offs_m
        lse = m_i + tl.math.log2(l_i)
        if IS_DIVISIBLE:
            tl.store(l_ptrs, lse)
        else:
            tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
"""

compute_forward_inner = r"""
@triton.jit
def forward_inner(
    {{gen_argdefs()}},
    q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets used as inputs to score_mod & mask_mod
    # of size [BLOCK_M, BLOCK_N] or scalar.
    off_z, off_h, offs_m, offs_n,
    # blocksparse data
    kv_indices, kv_num_blocks,
    # start kv and end kv block
    block_n_start, block_n_end,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    {{gen_defines() | indent_except_first(1)}}

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    RCP_LN2: tl.constexpr = 1.44269504

    if PRESCALE_QK:
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    # loop over k, v and update accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                {{gen_argdefs()}},
                q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                MATMUL_PRECISION, RCP_LN2,
                IS_FULL_BLOCKS,
            )
        else:
            # Benchmarks show that applying mod & mask to every block for
            # non-divisible seqlens is on par with, or slightly faster than,
            # applying them only to the last block in the forward pass.
            # We choose a different strategy for the backward pass, where we
            # only apply mod & mask to the last block, because there it is much faster.
            acc, l_i, m_i = forward_block_mn(
                {{gen_argdefs()}},
                q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                MATMUL_PRECISION, RCP_LN2,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )

        # update pointers
        offset = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N
        )

        V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, offset))

        offs_n = offs_n + offset

    return acc, l_i, m_i

"""

compute_forward_block_mn = r"""
@triton.jit
def forward_block_mn(
    {{gen_argdefs()}},
    q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets
    off_z, off_h, offs_m, offs_n,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    {{gen_defines() | indent_except_first(1)}}

    # -- load k --
    if IS_DIVISIBLE:
        k = tl.load(K_block_ptr)
    else:
        k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION)  # TODO: use cuda matmul when q_len <= 2.
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification  ~~~~~~~~~~~~~~~~~~~
    if CHECK_BLOCK_BOUNDARY:
        # If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
        # which is larger than the actual number of elements. To avoid out-of-bounds memory accesses,
        # we need to mask out the elements that are beyond Q_LEN & KV_LEN.
        m = offs_m % Q_LEN
        n = offs_n % KV_LEN
    else:
        m = offs_m
        n = offs_n

    {{ modification(
        subgraph_number=0,
        output_name="post_mod_scores",
        score="qk",
        b="off_z",
        h="off_h",
        m="m",
        n="n",
        out="qk"
    ) | indent_except_first(1) }}

    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        {{ modification(
            subgraph_number=1,
            output_name="mask_mod_output",
            score="qk",
            b="off_z",
            h="off_h",
            m="m",
            n="n",
        ) | indent_except_first(2) }}

        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, float("-inf"))
        # apply mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

    # TODO: In the case that score_mod is linear, this can be LICMed
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # -- compute scaling constant ---
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij

    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])

    # NB: l_i update is pulled up here since it's a bit faster
    # NB: For headdim=256, it's faster to move it back down to after m_i =
    # m_ij
    l_i = l_i * alpha + tl.sum(p, 1)
    # -- scale and update acc --
    acc = acc * alpha[:, None]

    if IS_DIVISIBLE:
        v = tl.load(V_block_ptr)
    else:
        v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
    acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

    # -- update m_i
    m_i = m_ij

    return acc, l_i, m_i

"""

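# Illustrative sketch (not part of the original file): the online-softmax
# update that forward_block_mn performs, written as a plain PyTorch reference
# over full [M, N] scores. The kernel uses exp2 with an RCP_LN2 change of
# base; plain exp is used here for clarity. Function name is hypothetical.
def _demo_online_softmax(scores: torch.Tensor, v: torch.Tensor, block_n: int):
    M, N = scores.shape
    m_i = torch.full((M,), float("-inf"))
    l_i = torch.zeros(M)
    acc = torch.zeros(M, v.shape[-1])
    for start in range(0, N, block_n):
        s = scores[:, start:start + block_n]
        m_ij = torch.maximum(m_i, s.max(dim=1).values)
        alpha = torch.exp(m_i - m_ij)  # rescale the old accumulator
        p = torch.exp(s - m_ij[:, None])
        l_i = l_i * alpha + p.sum(dim=1)
        acc = acc * alpha[:, None] + p @ v[start:start + block_n]
        m_i = m_ij
    return acc / l_i[:, None]  # matches softmax(scores) @ v
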
flex_attention_template = TritonTemplate(
    name="flex_attention",
    grid=flex_attention_grid,
    source=compute_flex_attention
    + compute_forward_inner
    + compute_next_offset_func
    + compute_forward_block_mn,
)


def _use_flex_decoding(query, kernel_options):
    # Decide which kernel to use; return True to use the flex decoding kernel.
    return (
        not kernel_options.get("FORCE_USE_FLEX_ATTENTION", False)
    ) and V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 128))

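# Illustrative sketch (not part of the original file): the decoding heuristic
# in plain Python, with the symbolic sizevars check replaced by a concrete
# integer comparison. Short query sequences (< 128, e.g. a single decode
# token) are routed to the split-k flex decoding kernel unless the caller
# forces the main kernel via kernel_options. Names below are hypothetical.
def _demo_would_use_flex_decoding(q_len: int, kernel_options: dict) -> bool:
    return not kernel_options.get("FORCE_USE_FLEX_ATTENTION", False) and q_len < 128


assert _demo_would_use_flex_decoding(1, {}) is True      # decode step
assert _demo_would_use_flex_decoding(2048, {}) is False  # prefill
assert _demo_would_use_flex_decoding(1, {"FORCE_USE_FLEX_ATTENTION": True}) is False
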
_h100_default_config = {
    (torch.float32, 64): (128, 32, 4, 3),
    (torch.float32, 128): (32, 64, 4, 3),
    (torch.float32, 256): (32, 32, 4, 3),
    (torch.bfloat16, 64): (128, 128, 4, 3),
    (torch.bfloat16, 128): (128, 64, 8, 3),
    (torch.bfloat16, 256): (64, 32, 4, 3),
    (torch.float16, 64): (128, 128, 4, 3),
    (torch.float16, 128): (128, 128, 8, 3),
    (torch.float16, 256): (64, 32, 4, 3),
}

_a100_default_config = {
    (torch.float32, 64): (128, 32, 4, 3),
    (torch.float32, 128): (128, 32, 4, 3),
    (torch.float32, 256): (64, 16, 4, 3),
    (torch.bfloat16, 64): (128, 64, 4, 3),
    (torch.bfloat16, 128): (128, 64, 8, 3),
    (torch.bfloat16, 256): (32, 64, 4, 3),
    (torch.float16, 64): (128, 64, 4, 3),
    (torch.float16, 128): (128, 64, 8, 3),
    (torch.float16, 256): (32, 64, 4, 3),
}


def _get_default_config_fwd(query) -> Tuple[int, int, int, int]:
    dtype = query.get_dtype()
    head_dim = query.get_size()[-1]
    default_config = None

    if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
        if dtype == torch.float32:
            default_config = (64, 64, 4, 3)
        else:
            default_config = (128, 64, 4, 3)
        default_config = _h100_default_config.get((dtype, head_dim), default_config)
    elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0):  # A100
        if dtype == torch.float32:
            default_config = (64, 64, 4, 3)
        else:
            default_config = (128, 64, 4, 3)
        default_config = _a100_default_config.get((dtype, head_dim), default_config)
    else:  # modest hardware or extremely large head_dim
        if dtype == torch.float32:
            default_config = (32, 16, 4, 3)
        else:
            default_config = (64, 32, 4, 3)

    return default_config


def _get_default_config_bwd(query) -> Tuple[int, int, int, int]:
    head_dim = query.get_size()[-1]
    dtype = query.get_dtype()

    if dtype == torch.float32:
        return (16, 16, 4, 1)
    if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
        if head_dim == 64:
            return (64, 64, 4, 3)
        elif head_dim == 128:
            return (64, 128, 8, 3)
        else:
            return (64, 64, 4, 2)
    elif torch.cuda.get_device_capability() >= (8, 0):  # A100
        if head_dim == 64:
            return (32, 128, 4, 3)
        elif head_dim == 128:
            return (64, 128, 8, 3)
        else:
            return (64, 64, 4, 2)
    else:  # modest hardware or extremely large head_dim
        return (16, 16, 4, 1)

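# Illustrative note (not part of the original file): each config tuple above
# is (BLOCK_M, BLOCK_N, num_warps, num_stages), matching how the autotuning
# loop in flex_attention below unpacks them. E.g. the H100 (torch.bfloat16,
# 128) entry (128, 64, 8, 3) means 128 query rows per program, 64 KV columns
# per inner iteration, 8 warps, and a 3-stage software pipeline. The function
# name is hypothetical.
def _demo_read_config():
    BLOCK_M, BLOCK_N, num_warps, num_stages = _h100_default_config[(torch.bfloat16, 128)]
    assert (BLOCK_M, BLOCK_N, num_warps, num_stages) == (128, 64, 8, 3)
    return BLOCK_M, BLOCK_N, num_warps, num_stages
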
| 639 |
+
# The idea here is that we need to create a real tensor with real data
|
| 640 |
+
# that's representative for benchmarking.
|
| 641 |
+
# For example, returning all zeros for the `kv_num_blocks` input would mean
|
| 642 |
+
# that we are computing 0 blocks for each row, which would provide bogus
|
| 643 |
+
# autotuning results.
|
| 644 |
+
#
|
| 645 |
+
# In this case, we choose to use min(16, max_block) blocks, because I
|
| 646 |
+
# (Horace) think it'll probably result in pretty representative performance.
|
| 647 |
+
# If it's too short then prefetching won't help. If it's too long then
|
| 648 |
+
# autotuning will take longer for no good reason.
|
| 649 |
+
def create_num_blocks_fake(x) -> torch.Tensor:
|
| 650 |
+
num_blocks_for_autotuning = min(16, sparse_indices.shape[-1])
|
| 651 |
+
return torch.full(
|
| 652 |
+
x.get_size(),
|
| 653 |
+
int(num_blocks_for_autotuning),
|
| 654 |
+
dtype=x.get_dtype(),
|
| 655 |
+
device=x.get_device(),
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
return create_num_blocks_fake
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
def create_indices_fake(x) -> torch.Tensor:
|
| 662 |
+
indices = torch.arange(
|
| 663 |
+
0, int(x.get_size()[-1]), dtype=x.get_dtype(), device=x.get_device()
|
| 664 |
+
)
|
| 665 |
+
indices = indices.expand(x.get_size()).contiguous()
|
| 666 |
+
return indices
|
| 667 |
+
|
| 668 |
+
|
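# Illustrative sketch (not part of the original file): the tensors these fake
# generators produce, shown with plain shapes instead of inductor IR nodes.
# For a [1, 1, 4, 8] KV_IDX buffer, the fake indices are 0..7 broadcast along
# the leading dims, and the matching fake num-blocks buffer is filled with
# min(16, 8) = 8 so the benchmarked kernel actually iterates. Names are
# hypothetical.
def _demo_fake_autotune_inputs():
    kv_idx_shape = (1, 1, 4, 8)
    indices = torch.arange(0, kv_idx_shape[-1], dtype=torch.int32)
    indices = indices.expand(kv_idx_shape).contiguous()
    num_blocks = torch.full(kv_idx_shape[:-1], min(16, kv_idx_shape[-1]), dtype=torch.int32)
    return num_blocks, indices
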
from torch._inductor.kernel.flex_decoding import create_flex_decoding_kernel


# TODO: We probably also need a layout constraint?
@register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
def flex_attention(
    query,
    key,
    value,
    subgraph,
    block_mask,
    scale,
    kernel_options,
    score_mod_other_buffers,
    mask_mod_other_buffers,
):
    (
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
        SPARSE_KV_BLOCK_SIZE,
        SPARSE_Q_BLOCK_SIZE,
        mask_graph,
    ) = block_mask
    placeholder_inps = [
        create_placeholder(name, dtype, query.get_device())
        for name, dtype in [
            ("score", query.get_dtype()),
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    subgraph_buffer = build_subgraph_buffer(
        placeholder_inps + list(score_mod_other_buffers), subgraph
    )
    mask_graph_placeholder_inps = [
        create_placeholder(name, dtype, query.get_device())
        for name, dtype in [
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    mask_graph_buffer = build_subgraph_buffer(
        mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
    )
    kernel_options = dict(kernel_options)
    kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
    if _use_flex_decoding(query, kernel_options):
        return create_flex_decoding_kernel(
            query,
            key,
            value,
            block_mask,
            scale,
            kernel_options,
            subgraph_buffer,
            mask_graph_buffer,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )

    (
        query,
        key,
        value,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
    ) = maybe_realize(
        [
            query,
            key,
            value,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            q_num_blocks,
            q_indices,
            full_q_num_blocks,
            full_q_indices,
        ]
    )

    Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
    Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
    assert Bq == Bkv, "Batch dimension must match"
    B = Bq

    if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0:
        kernel_options.setdefault("IS_DIVISIBLE", False)
    else:
        kernel_options.setdefault("IS_DIVISIBLE", True)

    # Reuse query strides for output layout despite different last dimension.
    # This works because only the last dim differs and we check it is contiguous.
    q_strides = query.get_stride()
    assert q_strides[-1] == 1, "Query must be contiguous in the last dimension"

    # Construct output layout with strides matching the query.
    out_size = [B, Hq, seq_len_q, v_head_dim]
    stride_order = get_stride_order(query.get_stride())
    fill_order = stride_order2fill_order(stride_order)
    out_strides = construct_strides(out_size, fill_order)

    layout = FixedLayout(
        query.get_device(),
        query.get_dtype(),
        [B, Hq, seq_len_q, v_head_dim],
        stride=out_strides,
    )
    # see NOTE:[TritonTemplates with multiple outputs]
    logsumexp_shape = [B, Hq, seq_len_q]
    logsumexp = empty_strided(
        logsumexp_shape,
        None,
        dtype=torch.float32,  # The logsumexp is always stored in fp32 regardless of the input dtype
        device=query.get_device(),
    )
    kernel_options.setdefault("SM_SCALE", scale)

    # Determine GQA broadcast factor.
    gqa_shared_heads = Hq // Hkv
    kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)

    # Inside the Triton kernel, only apply partial masking if partial blocks are computed.
    # full_kv_num_blocks is None if partial blocks are not computed.
    has_full_blocks = full_kv_num_blocks is not None
    kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
    if not has_full_blocks:
        full_kv_num_blocks, full_kv_indices = (
            empty(0, device=query.get_device()) for _ in range(2)
        )
    kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim)
    kernel_options.setdefault("V_HEAD_DIM", v_head_dim)

    choices: List[Any] = []
    configs: List[Tuple[int, int, int, int]] = []
    configs.append(_get_default_config_fwd(query))
    if config.max_autotune:
        configs += [
            (128, 64, 4, 3),
            (128, 128, 4, 3),
            (128, 128, 8, 2),
            (64, 128, 4, 3),
            (64, 64, 4, 3),
        ]

    # Note: we don't need to pass in the captured buffers explicitly
    # because they're implicitly added by the score_mod function.
    # We do need to pass them in explicitly for autotuning, though.

    for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
        if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0 or SPARSE_Q_BLOCK_SIZE % BLOCK_M != 0:
            continue
        # Work around https://github.com/pytorch/pytorch/issues/129625
        if num_stages == 2:
            continue

        # Performance tuning
        kernel_options.setdefault("BLOCK_M", BLOCK_M)
        kernel_options.setdefault("BLOCK_N", BLOCK_N)
        # Blocksparse options
        kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
        kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)

        flex_attention_template.maybe_append_choice(
            choices=choices,
            input_nodes=[
                query,
                key,
                value,
                logsumexp,
                kv_num_blocks,
                kv_indices,
                full_kv_num_blocks,
                full_kv_indices,
            ],
            layout=layout,
            subgraphs=[
                subgraph_buffer,
                mask_graph_buffer,
            ],
            mutated_inputs=[
                logsumexp,
            ],
            num_stages=num_stages,
            num_warps=num_warps,
            call_sizes=query.get_size(),
            **kernel_options,
        )
    inputs_for_autotuning = (
        [
            query,
            key,
            value,
            logsumexp,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
        ]
        + list(score_mod_other_buffers)
        + list(mask_mod_other_buffers)
    )
    input_gen_fns = {
        4: create_num_blocks_fake_generator(kv_indices),
        5: create_indices_fake,
        6: create_num_blocks_fake_generator(full_kv_indices),
        7: create_indices_fake,
    }
    return (
        autotune_select_algorithm(
            "flex_attention",
            choices,
            inputs_for_autotuning,
            layout,
            input_gen_fns=input_gen_fns,
        ),
        logsumexp,
    )

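# Illustrative sketch (not part of the original file): the user-level program
# that this lowering compiles. Assumes the public flex_attention API
# (torch.nn.attention.flex_attention in recent PyTorch releases) and a CUDA
# device; under torch.compile, the higher-order op is routed through the
# register_lowering above. The function name is hypothetical.
def _demo_flex_attention_usage():
    from torch.nn.attention.flex_attention import flex_attention as flex

    def rel_bias(score, b, h, q_idx, kv_idx):
        # score_mod: traced into the subgraph that build_subgraph_buffer lowers
        return score + (q_idx - kv_idx)

    q, k, v = (torch.randn(2, 8, 256, 64, device="cuda") for _ in range(3))
    compiled = torch.compile(flex)
    return compiled(q, k, v, score_mod=rel_bias)
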
# ---------------------------- Backward HOP Implementation ----------------------------


def flex_attention_backward_grid(
    batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta
):
    """How is this kernel parallelized?
    Currently this is only parallelized over batch_size * kv_heads, but we can,
    and want to, parallelize over ceil_div(q_heads // kv_heads * num_key_value, key_value_block_size).
    Doing so will either require atomic updates to some grad values or a two-pass kernel design.
    """
    import triton

    return (
        triton.cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads)
        + triton.cdiv(num_key_value, meta["BLOCK_N1"]),
        1,
        batch_size * kv_heads,
    )

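# Illustrative sketch (not part of the original file): the backward grid for
# a concrete shape. With 2048 queries, 2048 keys/values, 8 query heads
# sharing 2 kv heads, BLOCK_M2=64 and BLOCK_N1=64: axis 0 packs both the dQ
# programs (32 query blocks * 4 shared heads = 128) and the dK/dV programs
# (32 kv blocks), while axis 2 covers batch * kv_heads. Requires triton; the
# function name is hypothetical.
def _demo_flex_attention_backward_grid():
    grid = flex_attention_backward_grid(
        batch_size=4, q_heads=8, num_queries=2048, d_model=64,
        kv_heads=2, num_key_value=2048, meta={"BLOCK_M2": 64, "BLOCK_N1": 64},
    )
    assert grid == (128 + 32, 1, 4 * 2)
    return grid
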
+
flex_attention_backward_template = TritonTemplate(
|
| 928 |
+
name="flex_attention_backward",
|
| 929 |
+
grid=flex_attention_backward_grid,
|
| 930 |
+
source=r"""
|
| 931 |
+
{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}}
|
| 932 |
+
# Sub notation for this kernel:
|
| 933 |
+
#
|
| 934 |
+
# Q: Query, K: Key, V: Value
|
| 935 |
+
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
|
| 936 |
+
# DELTA: Precomputed sum(OUT*DO, axis=-1)
|
| 937 |
+
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
|
| 938 |
+
# DK: Derivative of Key, is the written to via the store_output call due to some limitations with
|
| 939 |
+
# inductor codegen
|
| 940 |
+
# M: Number of queries, N: Number of keys/values
|
| 941 |
+
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
| 942 |
+
# V_HEAD_DIM: The dimension of the value embeddings
|
| 943 |
+
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
|
| 944 |
+
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
| 945 |
+
# (Modifiable) Performance tuning options
|
| 946 |
+
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
|
| 947 |
+
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
|
| 948 |
+
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
|
| 949 |
+
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
|
| 950 |
+
#
|
| 951 |
+
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
|
| 952 |
+
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
| 953 |
+
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
| 954 |
+
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
|
| 955 |
+
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
|
| 956 |
+
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 957 |
+
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
| 958 |
+
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 959 |
+
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
|
| 960 |
+
|
| 961 |
+
# The below are kernel options that can be applied for certain score_mods,
|
| 962 |
+
# or involve a numerics vs. perf tradeoff
|
| 963 |
+
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
| 964 |
+
# about 20% more numerical error, but slightly faster.
|
| 965 |
+
|
| 966 |
+
# Define strides of inputs
|
| 967 |
+
stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}}
|
| 968 |
+
stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}}
|
| 969 |
+
stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}}
|
| 970 |
+
stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}}
|
| 971 |
+
|
| 972 |
+
stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}}
|
| 973 |
+
stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}}
|
| 974 |
+
|
| 975 |
+
Z = {{size("Q", 0)}}
|
| 976 |
+
HQ = {{size("Q", 1)}}
|
| 977 |
+
HKV = {{size("K", 1)}}
|
| 978 |
+
Q_LEN = {{size("Q", 2)}}
|
| 979 |
+
KV_LEN = {{size("K", 2)}}
|
| 980 |
+
|
| 981 |
+
MATMUL_PRECISION = Q.dtype.element_ty
|
| 982 |
+
|
| 983 |
+
pid = tl.program_id(0)
|
| 984 |
+
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
|
| 985 |
+
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
|
| 986 |
+
|
| 987 |
+
off_hz = tl.program_id(2)
|
| 988 |
+
off_z = off_hz // HKV # batch idx
|
| 989 |
+
off_hkv = off_hz % HKV # kv head idx
|
| 990 |
+
|
| 991 |
+
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
|
| 992 |
+
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
|
| 993 |
+
|
| 994 |
+
sparse_idx_z = off_z % SPARSE_Z
|
| 995 |
+
|
| 996 |
+
k_adj = (stride_kh * off_hkv + stride_kz * off_z).to(tl.int64)
|
| 997 |
+
v_adj = (stride_vh * off_hkv + stride_vz * off_z).to(tl.int64)
|
| 998 |
+
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_z).to(tl.int64)
|
| 999 |
+
|
| 1000 |
+
# offset K, V, DV pointers for batch/kv-head
|
| 1001 |
+
K += k_adj
|
| 1002 |
+
V += v_adj
|
| 1003 |
+
DV += dv_adj
|
| 1004 |
+
|
| 1005 |
+
RCP_LN2 = 1.44269504
|
| 1006 |
+
offs_k = tl.arange(0, QK_HEAD_DIM)
|
| 1007 |
+
offs_v = tl.arange(0, V_HEAD_DIM)
|
| 1008 |
+
|
| 1009 |
+
if pid >= NUM_KV_BLOCKS:
|
| 1010 |
+
off_pid = pid - NUM_KV_BLOCKS
|
| 1011 |
+
# THIS BLOCK DOES DQ
|
| 1012 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
|
| 1013 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 1014 |
+
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
|
| 1015 |
+
start_m2_block = off_pid % NUM_Q_BLOCKS
|
| 1016 |
+
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
|
| 1017 |
+
stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
|
| 1018 |
+
stride_kv_idx_h = {{stride("KV_IDX", 1)}}
|
| 1019 |
+
stride_kv_idx_m = {{stride("KV_IDX", 2)}}
|
| 1020 |
+
|
| 1021 |
+
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
|
| 1022 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
|
| 1023 |
+
|
| 1024 |
+
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
|
| 1025 |
+
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
|
| 1026 |
+
|
| 1027 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads.
|
| 1028 |
+
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_z).to(tl.int64)
|
| 1029 |
+
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_z).to(tl.int64)
|
| 1030 |
+
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_z).to(tl.int64)
|
| 1031 |
+
off_chz2 = ((off_z * HQ + off_hq2) * Q_LEN).to(tl.int64)
|
| 1032 |
+
|
| 1033 |
+
Q2 = Q + q_adj2
|
| 1034 |
+
DO2 = DO + do_adj2
|
| 1035 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 1036 |
+
# if Q is broadcasted)
|
| 1037 |
+
DQ2 = DQ + dq_adj2
|
| 1038 |
+
LSE2 = LSE + off_chz2
|
| 1039 |
+
DELTA2 = DELTA + off_chz2
|
| 1040 |
+
|
| 1041 |
+
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
|
| 1042 |
+
|
| 1043 |
+
start_m2 = start_m2_block * BLOCK_M2
|
| 1044 |
+
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
|
| 1045 |
+
|
| 1046 |
+
# load Q and do: they stay in SRAM throughout the inner loop.
|
| 1047 |
+
if IS_DIVISIBLE:
|
| 1048 |
+
q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd)
|
| 1049 |
+
do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_v[None, :] * stride_dod)
|
| 1050 |
+
else:
|
| 1051 |
+
q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd, mask=offs_m2[:, None] < Q_LEN)
|
| 1052 |
+
do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_v[None, :] * stride_dod, mask=offs_m2[:, None] < Q_LEN)
|
| 1053 |
+
|
| 1054 |
+
if PRESCALE_QK:
|
| 1055 |
+
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 1056 |
+
|
| 1057 |
+
if IS_DIVISIBLE:
|
| 1058 |
+
Di = tl.load(DELTA2 + offs_m2)
|
| 1059 |
+
lse = tl.load(LSE2 + offs_m2)
|
| 1060 |
+
else:
|
| 1061 |
+
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 1062 |
+
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
|
| 1063 |
+
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
| 1064 |
+
lse = lse[:, None]
|
| 1065 |
+
|
| 1066 |
+
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 1067 |
+
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
| 1068 |
+
kv_indices = KV_IDX + sparse_kv_idx_offset
|
| 1069 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 1070 |
+
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 1071 |
+
|
| 1072 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 1073 |
+
dq = bwd_dq_inner(
|
| 1074 |
+
{{gen_argdefs()}},
|
| 1075 |
+
K, V,
|
| 1076 |
+
dq, q, do, Di, lse,
|
| 1077 |
+
off_z, off_hq2, offs_m2, offs_n2,
|
| 1078 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 1079 |
+
kv_indices, sparse_kv_num_blocks,
|
| 1080 |
+
MATMUL_PRECISION,
|
| 1081 |
+
IS_FULL_BLOCKS=False,
|
| 1082 |
+
)
|
| 1083 |
+
|
| 1084 |
+
if HAS_FULL_BLOCKS:
|
| 1085 |
+
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 1086 |
+
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
| 1087 |
+
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
| 1088 |
+
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
| 1089 |
+
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
| 1090 |
+
|
| 1091 |
+
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
| 1092 |
+
dq = bwd_dq_inner(
|
| 1093 |
+
{{gen_argdefs()}},
|
| 1094 |
+
K, V,
|
| 1095 |
+
dq, q, do, Di, lse,
|
| 1096 |
+
off_z, off_hq2, offs_m2, offs_n2,
|
| 1097 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 1098 |
+
kv_indices, sparse_kv_num_blocks,
|
| 1099 |
+
MATMUL_PRECISION,
|
| 1100 |
+
IS_FULL_BLOCKS=True,
|
| 1101 |
+
)
|
| 1102 |
+
|
| 1103 |
+
# Write back dQ.
|
| 1104 |
+
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
|
| 1105 |
+
dq *= SM_SCALE
|
| 1106 |
+
if IS_DIVISIBLE:
|
| 1107 |
+
tl.store(dq_ptrs, dq)
|
| 1108 |
+
else:
|
| 1109 |
+
tl.store(dq_ptrs, dq, mask=offs_m2[:, None] < Q_LEN)
|
| 1110 |
+
else:
|
| 1111 |
+
# THIS BLOCK DOES DK & DV
|
| 1112 |
+
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
| 1113 |
+
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
|
| 1114 |
+
|
| 1115 |
+
pid_mask = pid // SPARSE_KV_MULTIPLE
|
| 1116 |
+
|
| 1117 |
+
stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}}
|
| 1118 |
+
stride_q_idx_h = {{stride("Q_IDX", 1)}}
|
| 1119 |
+
stride_q_idx_n = {{stride("Q_IDX", 2)}}
|
| 1120 |
+
|
| 1121 |
+
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM], dtype=tl.float32)
|
| 1122 |
+
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM], dtype=tl.float32)
|
| 1123 |
+
|
| 1124 |
+
start_n1 = pid * BLOCK_N1
|
| 1125 |
+
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
|
| 1126 |
+
|
| 1127 |
+
# load K and V: they stay in SRAM throughout the inner loop.
|
| 1128 |
+
if IS_DIVISIBLE:
|
| 1129 |
+
k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd)
|
| 1130 |
+
v = tl.load(V + offs_n1[:, None] * stride_vn + offs_v[None, :] * stride_vd)
|
| 1131 |
+
else:
|
| 1132 |
+
k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd, mask=offs_n1[:, None] < KV_LEN)
|
| 1133 |
+
v = tl.load(V + offs_n1[:, None] * stride_vn + offs_v[None, :] * stride_vd, mask=offs_n1[:, None] < KV_LEN)
|
| 1134 |
+
if PRESCALE_QK:
|
| 1135 |
+
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
| 1136 |
+
|
| 1137 |
+
for off_g in range(0, GQA_SHARED_HEADS):
|
| 1138 |
+
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
|
| 1139 |
+
|
| 1140 |
+
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads.
|
| 1141 |
+
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_z).to(tl.int64)
|
| 1142 |
+
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_z).to(tl.int64)
|
| 1143 |
+
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_z).to(tl.int64)
|
| 1144 |
+
off_chz1 = ((off_z * HQ + off_hq1) * Q_LEN).to(tl.int64)
|
| 1145 |
+
|
| 1146 |
+
Q1 = Q + q_adj1
|
| 1147 |
+
DO1 = DO + do_adj1
|
| 1148 |
+
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
| 1149 |
+
# if Q is broadcasted)
|
| 1150 |
+
LSE1 = LSE + off_chz1
|
| 1151 |
+
DELTA1 = DELTA + off_chz1
|
| 1152 |
+
|
| 1153 |
+
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
|
| 1154 |
+
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
|
| 1155 |
+
|
| 1156 |
+
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
|
| 1157 |
+
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
|
| 1158 |
+
|
| 1159 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 1160 |
+
# Q_IDX and Q_NUM_BLKS are always contiguous.
|
| 1161 |
+
q_indices = Q_IDX + sparse_q_idx_offset
|
| 1162 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 1163 |
+
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 1164 |
+
|
| 1165 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 1166 |
+
dk, dv = bwd_dkdv_inner(
|
| 1167 |
+
{{gen_argdefs()}},
|
| 1168 |
+
Q1, DO1, DELTA1, LSE1,
|
| 1169 |
+
dk, dv, k, v,
|
| 1170 |
+
off_z, off_hq1, offs_n1, offs_m1,
|
| 1171 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 1172 |
+
q_indices, sparse_q_num_blocks,
|
| 1173 |
+
MATMUL_PRECISION,
|
| 1174 |
+
IS_FULL_BLOCKS=False,
|
| 1175 |
+
)
|
| 1176 |
+
|
| 1177 |
+
|
| 1178 |
+
if HAS_FULL_BLOCKS:
|
| 1179 |
+
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 1180 |
+
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
|
| 1181 |
+
q_indices = FULL_Q_IDX + sparse_q_idx_offset
|
| 1182 |
+
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
| 1183 |
+
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
|
| 1184 |
+
|
| 1185 |
+
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
| 1186 |
+
dk, dv = bwd_dkdv_inner(
|
| 1187 |
+
{{gen_argdefs()}},
|
| 1188 |
+
Q1, DO1, DELTA1, LSE1,
|
| 1189 |
+
dk, dv, k, v,
|
| 1190 |
+
off_z, off_hq1, offs_n1, offs_m1,
|
| 1191 |
+
stride_qm, stride_qd, stride_dom, stride_dod,
|
| 1192 |
+
q_indices, sparse_q_num_blocks,
|
| 1193 |
+
MATMUL_PRECISION,
|
| 1194 |
+
IS_FULL_BLOCKS=True,
|
| 1195 |
+
)
|
| 1196 |
+
|
| 1197 |
+
# Write back dV and dK.
|
| 1198 |
+
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
|
| 1199 |
+
|
| 1200 |
+
index_n = offs_n1[:, None]
|
| 1201 |
+
index_k = offs_k[None, :]
|
| 1202 |
+
|
| 1203 |
+
if IS_DIVISIBLE:
|
| 1204 |
+
tl.store(dv_ptrs, dv)
|
| 1205 |
+
else:
|
| 1206 |
+
tl.store(dv_ptrs, dv, mask=index_n < KV_LEN)
|
| 1207 |
+
|
| 1208 |
+
dk *= SM_SCALE
|
| 1209 |
+
mask = index_n < KV_LEN
|
| 1210 |
+
{{store_output(("off_z", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
|
| 1211 |
+
|
| 1212 |
+
@triton.jit
|
| 1213 |
+
def bwd_dq_inner(
|
| 1214 |
+
{{gen_argdefs()}},
|
| 1215 |
+
K, V, # pointers
|
| 1216 |
+
dq, q, do, Di, lse,
|
| 1217 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 1218 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 1219 |
+
kv_indices, sparse_kv_num_blocks,
|
| 1220 |
+
MATMUL_PRECISION,
|
| 1221 |
+
IS_FULL_BLOCKS,
|
| 1222 |
+
):
|
| 1223 |
+
{{gen_defines() | indent_except_first(1) }}
|
| 1224 |
+
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
| 1225 |
+
RCP_LN2: tl.constexpr = 1.44269504
|
| 1226 |
+
Q_LEN = {{size("Q", 2)}}
|
| 1227 |
+
KV_LEN = {{size("K", 2)}}
|
| 1228 |
+
|
| 1229 |
+
offs_k = tl.arange(0, QK_HEAD_DIM)
|
| 1230 |
+
offs_v = tl.arange(0, V_HEAD_DIM)
|
| 1231 |
+
|
| 1232 |
+
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
|
| 1233 |
+
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
|
| 1234 |
+
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
| 1235 |
+
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
| 1236 |
+
|
| 1237 |
+
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
|
| 1238 |
+
if not IS_DIVISIBLE:
|
| 1239 |
+
if hi >= 1:
|
| 1240 |
+
for start_n in range(0, hi - 1):
|
| 1241 |
+
dq = bwd_dq_block_mn(
|
| 1242 |
+
{{gen_argdefs()}},
|
| 1243 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 1244 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 1245 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 1246 |
+
kv_indices, sparse_kv_num_blocks,
|
| 1247 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 1248 |
+
IS_FULL_BLOCKS,
|
| 1249 |
+
)
|
| 1250 |
+
|
| 1251 |
+
# Increment pointers.
|
| 1252 |
+
offset = get_offset_for_next_block(
|
| 1253 |
+
start_n, kv_indices, sparse_kv_num_blocks,
|
| 1254 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2
|
| 1255 |
+
)
|
| 1256 |
+
|
| 1257 |
+
kT_ptrs += offset * stride_kn
|
| 1258 |
+
vT_ptrs += offset * stride_vn
|
| 1259 |
+
|
| 1260 |
+
offs_n2 += offset
|
| 1261 |
+
|
| 1262 |
+
dq = bwd_dq_block_mn(
|
| 1263 |
+
{{gen_argdefs()}},
|
| 1264 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 1265 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 1266 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 1267 |
+
kv_indices, sparse_kv_num_blocks,
|
| 1268 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 1269 |
+
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
| 1270 |
+
)
|
| 1271 |
+
else:
|
| 1272 |
+
for start_n in range(0, hi):
|
| 1273 |
+
dq = bwd_dq_block_mn(
|
| 1274 |
+
{{gen_argdefs()}},
|
| 1275 |
+
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
| 1276 |
+
off_z, off_hq, offs_m2, offs_n2,
|
| 1277 |
+
stride_kn, stride_kd, stride_vn, stride_vd,
|
| 1278 |
+
kv_indices, sparse_kv_num_blocks,
|
| 1279 |
+
MATMUL_PRECISION, RCP_LN2,
|
| 1280 |
+
IS_FULL_BLOCKS,
|
| 1281 |
+
)
|
| 1282 |
+
|
| 1283 |
+
# Increment pointers.
|
| 1284 |
+
offset = get_offset_for_next_block(
|
| 1285 |
+
start_n, kv_indices, sparse_kv_num_blocks,
|
| 1286 |
+
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2
|
| 1287 |
+
)
|
| 1288 |
+
|
| 1289 |
+
kT_ptrs += offset * stride_kn
|
| 1290 |
+
vT_ptrs += offset * stride_vn
|
| 1291 |
+
|
| 1292 |
+
offs_n2 += offset
|
| 1293 |
+
|
| 1294 |
+
return dq
|
| 1295 |
+
|
| 1296 |
+
|
| 1297 |
+
@triton.jit
|
| 1298 |
+
def bwd_dq_block_mn(
|
| 1299 |
+
{{gen_argdefs()}},
|
    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
    {{gen_defines() | indent_except_first(1)}}

    if IS_DIVISIBLE:
        kT = tl.load(kT_ptrs)
    else:
        kT = tl.load(kT_ptrs, mask=offs_n2[None, :] < KV_LEN)
    qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    pre_mod_scores = qk
    if CHECK_BLOCK_BOUNDARY:
        m = offs_m2[:, None] % Q_LEN
        n = offs_n2[None, :] % KV_LEN
    else:
        m = offs_m2[:, None]
        n = offs_n2[None, :]
    {{ modification(
        subgraph_number=0,
        output_name="post_mod_scores",
        score="qk",
        b="off_z",
        h="off_hq",
        m="m",
        n="n",
        out="qk"
    ) | indent_except_first(1) }}

    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        {{ modification(
            subgraph_number=2,
            output_name="mask_mod_output",
            score="qk",
            b="off_z",
            h="off_hq",
            m="m",
            n="n",
        ) | indent_except_first(2) }}

        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
        # apply mask for partial masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    p = tl.math.exp2(post_mod_scores - lse)
    # Compute dP and dS.
    if IS_DIVISIBLE:
        vT = tl.load(vT_ptrs)
    else:
        vT = tl.load(vT_ptrs, mask=offs_n2[None, :] < KV_LEN)
    dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
    ds = p * (dp - Di[:, None])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    {{ modification(
        subgraph_number=1,
        output_name="grad_scores",
        score="pre_mod_scores",
        b="off_z",
        h="off_hq",
        m="m",
        n="n",
        grad_score_mod="ds"
    ) | indent_except_first(1) }}
    if CHECK_BLOCK_BOUNDARY:
        grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)

    ds = grad_scores

    if not IS_FULL_BLOCKS:
        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
        # (grads) apply mask for partially unmasked block
        ds = tl.where(mask_mod_output, ds, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    ds = ds.to(MATMUL_PRECISION)
    # Compute dQ.
    dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)

    return dq

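
# A reader's note on the identity used above (a sketch, assuming the
# standard softmax backward; not specific to this kernel): with
# P = softmax(S) rowwise and O = P @ V, the score gradient is
#     dS = P * (dP - Di[:, None]),    dP = dO @ V^T,
# where Di is the per-row term rowsum(dO * O), adjusted by any gradient
# flowing into the returned logsumexp (see the "delta" computed in the
# lowering further below). The base-2 form (tl.math.exp2 with lse kept in
# log2 units) only rescales S by RCP_LN2 = 1/ln(2); the identity is
# otherwise unchanged.
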
@triton.jit
def bwd_dkdv_inner(
    {{gen_argdefs()}},
    Q, DO, DELTA, LSE, # pointers
    dk, dv, k, v,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    {{gen_defines() | indent_except_first(1) }}
    SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = {{size("Q", 2)}}
    KV_LEN = {{size("K", 2)}}

    offs_k = tl.arange(0, QK_HEAD_DIM)
    offs_v = tl.arange(0, V_HEAD_DIM)

    qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
    do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))

    if not IS_DIVISIBLE:
        if hi >= 1:
            for start_m in range(0, hi - 1):
                dk, dv = bwd_dkdv_block_mn(
                    {{gen_argdefs()}},
                    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
                    off_z, off_hq, offs_n1, offs_m1,
                    stride_qm, stride_qd, stride_dom, stride_dod,
                    q_indices, sparse_q_num_blocks,
                    MATMUL_PRECISION, RCP_LN2,
                    IS_FULL_BLOCKS,
                )
                # Increment pointers.
                offset = get_offset_for_next_block(
                    start_m, q_indices, sparse_q_num_blocks,
                    SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1
                )

                qT_ptrs += offset * stride_qm
                do_ptrs += offset * stride_dom

                offs_m1 += offset

            dk, dv = bwd_dkdv_block_mn(
                {{gen_argdefs()}},
                dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
                off_z, off_hq, offs_n1, offs_m1,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION, RCP_LN2,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )
    else:
        for start_m in range(0, hi):
            dk, dv = bwd_dkdv_block_mn(
                {{gen_argdefs()}},
                dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
                off_z, off_hq, offs_n1, offs_m1,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION, RCP_LN2,
                IS_FULL_BLOCKS,
            )
            # Increment pointers.
            offset = get_offset_for_next_block(
                start_m, q_indices, sparse_q_num_blocks,
                SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1
            )

            qT_ptrs += offset * stride_qm
            do_ptrs += offset * stride_dom

            offs_m1 += offset

    return dk, dv


@triton.jit
def bwd_dkdv_block_mn(
    {{gen_argdefs()}},
    dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
    off_z, off_hq, offs_n1, offs_m1,
    stride_qm, stride_qd, stride_dom, stride_dod,
    q_indices, sparse_q_num_blocks,
    MATMUL_PRECISION, RCP_LN2,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
    {{gen_defines() | indent_except_first(1) }}

    # Load LSE before computing qk to reduce pipeline stall.
    if IS_DIVISIBLE:
        qT = tl.load(qT_ptrs)
        lse = tl.load(LSE + offs_m1)
    else:
        qT = tl.load(qT_ptrs, mask=offs_m1[None, :] < Q_LEN)
        lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
    lse = tl.where(lse == -float("inf"), 0.0, lse)
    qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
    if not PRESCALE_QK:
        qkT *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
    if CHECK_BLOCK_BOUNDARY:
        m = offs_m1[None, :] % Q_LEN
        n = offs_n1[:, None] % KV_LEN
    else:
        m = offs_m1[None, :]
        n = offs_n1[:, None]
    pre_mod_scores = qkT
    {{ modification(
        subgraph_number=0,
        output_name="post_mod_scores",
        score="qkT",
        b="off_z",
        h="off_hq",
        m="m",
        n="n",
        out="qkT"
    ) | indent_except_first(1) }}

    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        {{ modification(
            subgraph_number=2,
            output_name="mask_mod_output",
            score="qkT",
            b="off_z",
            h="off_hq",
            m="m",
            n="n",
        ) | indent_except_first(2) }}
        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
        # (grads) apply mask for fully masked block
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    pT = tl.math.exp2(post_mod_scores - lse[None, :])
    if IS_DIVISIBLE:
        do = tl.load(do_ptrs)
    else:
        do = tl.load(do_ptrs, mask=offs_m1[:, None] < Q_LEN)
    # Compute dV.
    ppT = pT
    dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
    if IS_DIVISIBLE:
        Di = tl.load(DELTA + offs_m1)
    else:
        Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
    # Compute dP and dS.
    dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
    {{ modification(
        subgraph_number=1,
        output_name="grad_scores",
        score="pre_mod_scores",
        b="off_z",
        h="off_hq",
        m="m",
        n="n",
        grad_score_mod="dsT"
    ) | indent_except_first(1) }}
    if CHECK_BLOCK_BOUNDARY:
        grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0)

    dsT = grad_scores
    if not IS_FULL_BLOCKS:
        if CHECK_BLOCK_BOUNDARY:
            mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
        # (grads) apply mask for partially unmasked block
        dsT = tl.where(mask_mod_output, dsT, 0.0)
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)

    return dk, dv
"""
    + compute_next_offset_func,
)


# TODO: We probably also need a layout constraint?
@register_lowering(
    torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None
)
def flex_attention_backward(*args, **kwargs):
    (
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ) = args
    (
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
        SPARSE_KV_BLOCK_SIZE,
        SPARSE_Q_BLOCK_SIZE,
        mask_graph,
    ) = block_mask

    (
        query,
        key,
        value,
        grad_out,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
    ) = maybe_realize(
        [
            query,
            key,
            value,
            grad_out,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            q_num_blocks,
            q_indices,
            full_q_num_blocks,
            full_q_indices,
        ]
    )

    device = query.get_device()
    dtype = query.get_dtype()
    Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
    Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
    assert Bq == Bkv, "Batch dimension must match"
    B = Bq

    kernel_options = dict(kernel_options)
    kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
    if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0:
        kernel_options.setdefault("IS_DIVISIBLE", False)
    else:
        kernel_options.setdefault("IS_DIVISIBLE", True)

    fwd_placeholder_inps = [
        create_placeholder(name, dtype, device)
        for name, dtype in [
            ("score", dtype),
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    fw_subgraph_buffer = build_subgraph_buffer(
        fwd_placeholder_inps + list(score_mod_other_buffers), fw_graph
    )

    joint_placeholder_inps = fwd_placeholder_inps + [
        create_placeholder("grad_score_mod", dtype, device)
    ]
    joint_subgraph_buffer, *_ = build_subgraph_buffer(
        joint_placeholder_inps + list(score_mod_other_buffers), joint_graph
    )

    mask_graph_placeholder_inps = [
        create_placeholder(name, dtype, query.get_device())
        for name, dtype in [
            ("b", torch.int32),
            ("h", torch.int32),
            ("m", torch.int32),
            ("n", torch.int32),
        ]
    ]
    mask_graph_buffer = build_subgraph_buffer(
        mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
    )

    layout_k = FixedLayout(
        key.get_device(),
        key.get_dtype(),
        key.get_size(),
        key.get_stride(),
    )

    # Create delta, which is needed for the bwd kernel
    grad_lse_exp2 = lowerings[aten.mul](grad_logsumexp, 1 / math.log(2))
    mul_delta = lowerings[aten.mul](out, grad_out)
    delta = lowerings[aten.sum](mul_delta, axis=-1)
    delta = lowerings[aten.sub](delta, grad_lse_exp2)
    delta = ExternKernel.require_contiguous(delta)

    grad_lse_exp2, delta = maybe_realize([grad_lse_exp2, delta])

    # see NOTE:[TritonTemplates with multiple outputs]
    grad_query = empty_strided(
        query.get_size(), query.get_stride(), dtype=dtype, device=device
    )
    grad_value = empty_strided(
        value.get_size(), value.get_stride(), dtype=dtype, device=device
    )

    kernel_options.setdefault("SM_SCALE", scale)

    # Determine GQA factor
    gqa_shared_heads = Hq // Hkv
    kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)

    # Inside of Triton kernel, only apply partial masking if partial blocks are computed.
    # full_kv_num_blocks is torch.zeros([1, 1, 1]) if partial blocks are not computed.
    has_full_blocks = full_kv_num_blocks is not None
    kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
    if not has_full_blocks:
        full_kv_num_blocks, full_kv_indices, full_q_num_blocks, full_q_indices = (
            empty(0, device=query.get_device()) for _ in range(4)
        )
    kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim)
    kernel_options.setdefault("V_HEAD_DIM", v_head_dim)

    choices: List[Any] = []
    configs: List[Tuple[int, int, int, int]] = []
    configs.append(_get_default_config_bwd(query))
    if config.max_autotune:
        configs.extend(
            [
                (BLOCK1, BLOCK2, w, s)
                for BLOCK1 in [32, 64]
                for BLOCK2 in [32, 64, 128]
                for w in [4, 8]
                for s in [1, 3, 4, 5]
                if BLOCK2 % BLOCK1 == 0
            ]
        )

    for BLOCK1, BLOCK2, num_warps, num_stages in configs:
        if (
            SPARSE_KV_BLOCK_SIZE % BLOCK1 != 0
            or SPARSE_Q_BLOCK_SIZE % BLOCK1 != 0
            or SPARSE_KV_BLOCK_SIZE % BLOCK2 != 0
            or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
        ):
            continue

        # Performance tuning
        kernel_options.setdefault("BLOCK_M1", BLOCK1)
        kernel_options.setdefault("BLOCK_N1", BLOCK2)
        kernel_options.setdefault("BLOCK_M2", BLOCK2)
        kernel_options.setdefault("BLOCK_N2", BLOCK1)
        # Blocksparse options
        kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
        kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)

        flex_attention_backward_template.maybe_append_choice(
            choices=choices,
            input_nodes=[
                query,
                key,
                value,
                logsumexp,
                delta,
                grad_out,
                grad_query,
                grad_value,
                kv_num_blocks,
                kv_indices,
                q_num_blocks,
                q_indices,
                full_kv_num_blocks,
                full_kv_indices,
                full_q_num_blocks,
                full_q_indices,
            ],
            layout=layout_k,  # We use store_output only for grad_key
            subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer, mask_graph_buffer],
            mutated_inputs=[grad_query, grad_value],
            call_sizes=query.get_size() + key.get_size()[1:3],
            num_stages=num_stages,
            num_warps=num_warps,
            **kernel_options,
        )
    inputs_for_autotuning = (
        [
            query,
            key,
            value,
            logsumexp,
            delta,
            grad_out,
            grad_query,
            grad_value,
            kv_num_blocks,
            kv_indices,
            q_num_blocks,
            q_indices,
            full_kv_num_blocks,
            full_kv_indices,
            full_q_num_blocks,
            full_q_indices,
        ]
        + list(score_mod_other_buffers)
        + list(mask_mod_other_buffers)
    )
    input_gen_fns = {
        8: create_num_blocks_fake_generator(kv_indices),  # kv_num_blocks
        9: create_indices_fake,
        10: create_num_blocks_fake_generator(q_indices),  # q_num_blocks
        11: create_indices_fake,
        12: create_num_blocks_fake_generator(full_kv_indices),  # full_kv_num_blocks
        13: create_indices_fake,
        14: create_num_blocks_fake_generator(full_q_indices),  # full_q_num_blocks
        15: create_indices_fake,
    }

    grad_key = autotune_select_algorithm(
        "flex_attention_backward",
        choices,
        inputs_for_autotuning,
        layout_k,
        input_gen_fns=input_gen_fns,
    )
    return (
        grad_query,
        grad_key,
        grad_value,
    )
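
For reference, the `delta` tensor assembled above out of aten lowerings has a one-line eager equivalent. A minimal sketch (illustrative only; `reference_delta` is a hypothetical helper, and shapes of [B, Hq, M, D] for `out`/`grad_out` and [B, Hq, M] for `grad_logsumexp` are assumed):

import math
import torch

def reference_delta(out, grad_out, grad_logsumexp):
    # Mirrors the lowering: grad_lse_exp2 = grad_logsumexp / ln(2), then
    # delta = rowsum(out * grad_out) - grad_lse_exp2.
    grad_lse_exp2 = grad_logsumexp * (1 / math.log(2))
    return (out * grad_out).sum(dim=-1) - grad_lse_exp2
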
.venv/Lib/site-packages/torch/_inductor/kernel/flex_decoding.py
ADDED
@@ -0,0 +1,570 @@
# mypy: allow-untyped-defs
"""Triton implementation of the flex_attention kernel for short query length (FlexDecoding)."""
from typing import Any, List, Tuple

import sympy

import torch
from torch._inductor.virtualized import V

from .. import config, ir
from ..ir import FixedLayout, FlexibleLayout
from ..lowering import empty, empty_strided, lowerings
from ..runtime.runtime_utils import is_power_of_2, next_power_of_2
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from .flex_attention import (
    compute_forward_block_mn,
    compute_forward_inner,
    compute_next_offset_func,
    create_indices_fake,
    create_num_blocks_fake_generator,
    maybe_realize,
)


aten = torch.ops.aten
prims = torch.ops.prims


def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta):
    """How is this kernel parallelized?
    We create a grid of (batch_size * kv_heads, SPLIT_KV, 1).
    Each block is responsible for iterating over blocks of keys and values,
    calculating the local output for its tile of keys and values over the
    full length of the query. Groups of SPLIT_KV blocks then combine their
    outputs to produce the final result.
    """

    return (batch_size * kv_heads, meta["SPLIT_KV"], 1)


flex_decoding_template = TritonTemplate(
    name="flex_decoding",
    grid=flex_decoding_grid,
    source=r"""
    {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
    # Sub notation for this kernel:
    # Q: Query, K: Key, V: Value
    # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
    # M: Number of queries, N: Number of keys/values
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # BLOCK_M, QK_HEAD_DIM: M and D dimensions are always assigned to the same block
    # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head, t: Number of kv splits
    # (Modifiable) Config options:
    # SPLIT_KV: number of blocks K & V are split into
    # TILE_KV: length of each local KV split
    # BLOCK_M: block size that Q is padded along seqlen dim.
    # BLOCK_N: block size of K & V along N dimension.
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    #
    # change of base out of the loop
    # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
    # is not masked out? If so, we can skip an extra safety check
    # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
    # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.

    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
    #
    # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    #
    #
    # Output: ACC output accumulated across local KV split.

    tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)

    # Define Q Strides
    stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}}
    stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
    stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
    stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}}
    stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}}


    Z = {{size("Q", 0)}}
    HKV = {{size("Q", 1)}}
    G: tl.constexpr = GQA_SHARED_HEADS
    HQ = HKV * G
    Q_LEN = {{size("Q", 3)}}
    KV_LEN = {{size("K", 2)}}

    MATMUL_PRECISION = Q.dtype.element_ty

    # Make sure each split is a multiple of BLOCK_N
    TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
    TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
    TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)

    off_z = tl.program_id(0) // HKV
    off_hkv = tl.program_id(0) % HKV
    off_t = tl.program_id(1)

    q_offset = off_z * stride_qz + off_hkv * stride_qh
    k_offset = off_z * stride_kz + off_hkv * stride_kh
    v_offset = off_z * stride_vz + off_hkv * stride_vh

    SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
    SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}

    sparse_idx_z = off_z % SPARSE_Z
    # TODO: support masks not broadcasted along the head dimension.
    tl.device_assert(SPARSE_HQ == 1)
    sparse_idx_h = 0

    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)

    # initialize offsets
    tl.device_assert(BLOCK_M % G == 0)
    BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
    off_g = tl.arange(0, G)  # [G]
    offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ]))  # [BLOCK_M]
    offs_hq = offs_g + off_hkv * G
    off_m = tl.arange(0, BLOCK_M_PER_HQ)  # [BLOCK_M_PER_HQ]
    offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ]))  # [BLOCK_M]
    offs_d = tl.arange(0, QK_HEAD_DIM)
    offs_vd = tl.arange(0, V_HEAD_DIM)

    # KV_IDX / FULL_KV_IDX and KV_NUM_BLKS / FULL_KV_NUM_BLKS are always contiguous.
    sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_h

    # Calculate KV blocks that belong to this CTA.
    block_n_start = off_t * TILE_KV_MULTIPLE  # n_offset inside sparse block
    block_n_end = block_n_start + TILE_KV_MULTIPLE  # end BLOCK_N

    q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]

    if SAFE_M_BOUNDARY:
        q = tl.load(Q + q_offset + q_range)
    else:
        mask = off_m[None, :, None] < Q_LEN
        q = tl.load(Q + q_offset + q_range, mask)

    q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM])


    # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Apply both score_mod and mask_mod

    # find first kv block we are loading and the number of blocks we are loading
    kv_indices = KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
    kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_hz_offset)
    indices_idx = block_n_start // SPARSE_KV_MULTIPLE
    off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
    off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
    # first kv block we're loading

    # last valid block according to sparse mask
    block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(QK_HEAD_DIM, KV_LEN),  # (d, N)
        strides=(stride_kk, stride_kn),
        offsets=(0, off_n),
        block_shape=(QK_HEAD_DIM, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(KV_LEN, V_HEAD_DIM),
        strides=(stride_vn, stride_vk),
        offsets=(off_n, 0),
        block_shape=(BLOCK_N, V_HEAD_DIM),
        order=(1, 0)
    )
    offs_n = tl.arange(0, BLOCK_N) + off_n

    acc, l_i, m_i = forward_inner(
        {{gen_argdefs()}},
        q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
        # accumulated values
        acc, l_i, m_i,
        # offsets
        off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
        # block sparse data
        kv_indices, kv_num_blocks,
        block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
        MATMUL_PRECISION,
        IS_FULL_BLOCKS=False,
    )


    # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # We know these blocks are guaranteed to be "full", so we don't need to
    # apply mask_mod to them - only score_mod
    if HAS_FULL_BLOCKS:
        kv_indices = FULL_KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
        kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_hz_offset)
        indices_idx = block_n_start // SPARSE_KV_MULTIPLE
        off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
        off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N

        # last valid block according to sparse mask
        block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

        K_block_ptr = tl.make_block_ptr(
            base=K + k_offset,
            shape=(QK_HEAD_DIM, KV_LEN),  # (d, N)
            strides=(stride_kk, stride_kn),
            offsets=(0, off_n),
            block_shape=(QK_HEAD_DIM, BLOCK_N),
            order=(0, 1)
        )
        V_block_ptr = tl.make_block_ptr(
            base=V + v_offset,
            shape=(KV_LEN, V_HEAD_DIM),
            strides=(stride_vn, stride_vk),
            offsets=(off_n, 0),
            block_shape=(BLOCK_N, V_HEAD_DIM),
            order=(1, 0)
        )
        offs_n = tl.arange(0, BLOCK_N) + off_n

        acc, l_i, m_i = forward_inner(
            {{gen_argdefs()}},
            q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
            # accumulated values
            acc, l_i, m_i,
            # offsets
            off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
            # block sparse data
            kv_indices, kv_num_blocks,
            block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
            MATMUL_PRECISION,
            IS_FULL_BLOCKS=True,
        )

    m_offset = off_t * stride_mt + off_z * stride_mz
    l_offset = off_t * stride_lt + off_z * stride_lz

    M_block_ptr = tl.make_block_ptr(
        base=M + m_offset,
        shape=(G, Q_LEN),  # (G, M)
        strides=(stride_mh, stride_mm),
        offsets=(off_hkv*G, 0),
        block_shape=(G, BLOCK_M_PER_HQ),
        order=(1, 0)
    )
    L_block_ptr = tl.make_block_ptr(
        base=L + l_offset,
        shape=(G, Q_LEN),  # (G, M)
        strides=(stride_lh, stride_lm),
        offsets=(off_hkv*G, 0),
        block_shape=(G, BLOCK_M_PER_HQ),
        order=(1, 0)
    )

    # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
    m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
    l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
    if SAFE_M_BOUNDARY:
        tl.store(M_block_ptr, m_i)
        tl.store(L_block_ptr, l_i)
    else:
        tl.store(M_block_ptr, m_i, boundary_check=(1,))
        tl.store(L_block_ptr, l_i, boundary_check=(1,))

    # -- store output
    idx_z = off_z
    idx_t = off_t
    idx_hq = off_hkv*G + off_g[:, None, None]
    idx_m = off_m[None, :, None]
    idx_d = offs_vd[None, None, :]

    mask = (idx_m < Q_LEN)
    acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
    {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
"""
    + compute_forward_inner
    + compute_next_offset_func
    + compute_forward_block_mn,
)

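
# A reader's sketch of the tile arithmetic used by the template above
# (plain-Python mirror of the TILE_KV math; _tile_kv is a hypothetical
# helper, not used by the kernel): each of the SPLIT_KV programs covers an
# even share of KV_LEN, rounded up to a whole number of BLOCK_N blocks.
#
#     def _tile_kv(kv_len: int, split_kv: int, block_n: int) -> int:
#         tile_kv_og = -(-kv_len // split_kv)         # tl.cdiv(KV_LEN, SPLIT_KV)
#         return -(-tile_kv_og // block_n) * block_n  # round up to BLOCK_N
#
#     # e.g. _tile_kv(1000, 8, 64) == 128: two 64-wide blocks per split.
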
def get_split_k(B: int, H: int, Mk: int, SM: int = 128) -> int:
    """Heuristic for the number of splits, from xformers"""
    bh = max(B * H, 1)  # NOTE: Handle B * H = 0 case
    split_k = SM // bh  # Each SM should at least get one block.
    split_k = max(split_k, 1)

    return split_k


def _get_decoding_default_config(key) -> Tuple[int, int, int]:
    dtype = key.get_dtype()
    head_dim = key.get_size()[-1]
    sm_version = torch.cuda.get_device_capability()
    default_config = (64, 2, 1)
    if sm_version >= (9, 0):
        if head_dim > 128 and dtype == torch.float32:
            return default_config
        return (64, 2, 3)
    return default_config


def create_flex_decoding_kernel(*args, **kwargs):
    (
        query,
        key,
        value,
        block_mask,
        scale,
        kernel_options,
        score_mod_subgraph,
        mask_mod_subgraph,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ) = args
    (
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,  # full_kv_num_blocks,
        full_kv_indices,  # full_kv_indices,
        _,  # q_num_blocks
        _,  # q_indices
        _,  # full_q_num_blocks,
        _,  # full_q_indices,
        SPARSE_KV_BLOCK_SIZE,
        _,  # SPARSE_Q_BLOCK_SIZE,
        _,
    ) = block_mask

    Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
    Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
    assert Bq == Bkv, "Batch dimension must match"
    B = Bq
    kernel_options = dict(kernel_options)

    # TODO: Fix flex decoding non-divisible case!
    if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0:
        kernel_options.setdefault("IS_DIVISIBLE", False)
    else:
        kernel_options.setdefault("IS_DIVISIBLE", True)

    # Calculate GQA head sharing
    gqa_shared_heads = Hq // Hkv
    if not is_power_of_2(gqa_shared_heads):
        raise ValueError(
            "Number of query heads sharing the same KV head must be a power of 2. "
        )
    kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)

    # Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod
    has_full_blocks = full_kv_num_blocks is not None
    kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
    if not has_full_blocks:
        # Create a placeholder full block list in case it is empty
        full_kv_num_blocks, full_kv_indices = (
            empty(0, device=query.get_device()) for _ in range(2)
        )

    (
        query,
        key,
        value,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
    ) = maybe_realize(
        [
            query,
            key,
            value,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
        ]
    )

    choices: List[Any] = []
    configs: List[Tuple[int, int, int]] = []
    configs.append(_get_decoding_default_config(key))
    # Note: max_autotune is not supported yet. Causes error in lowering the dynamic shape in reduction ops.
    if config.max_autotune:
        configs += [
            (64, 2, 2),
            (32, 2, 3),
            (128, 2, 3),
        ]
    # TODO: fix autotuning.

    kernel_options.setdefault("SM_SCALE", scale)
    kernel_options.setdefault("SPLIT_KV", get_split_k(B, Hkv, seq_len_kv))
    MAX_SPLIT_KV = kernel_options["SPLIT_KV"]

    # create config dependent intermediate buffers
    buf_ACC_shape = [B, MAX_SPLIT_KV, Hq, seq_len_q, v_head_dim]
    buf_ML_shape = buf_ACC_shape[:-1]
    buf_M = empty_strided(
        buf_ML_shape,
        None,
        dtype=torch.float32,  # The rowmax is always stored in fp32 regardless of the input dtype
        device=query.get_device(),
    )
    buf_L = empty_strided(
        buf_ML_shape,
        None,
        dtype=torch.float32,  # The intermediate sumexp is always stored in fp32 regardless of the input dtype
        device=query.get_device(),
    )

    layout_acc = FixedLayout(
        query.get_device(),
        torch.float32,
        buf_ACC_shape,
        FlexibleLayout.contiguous_strides(buf_ACC_shape),
    )

    kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim)
    kernel_options.setdefault("V_HEAD_DIM", v_head_dim)

    kernel_options.setdefault(
        "BLOCK_M",
        (
            # m
            # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0))
            # else  # Always use a BLOCK_M > 16 before Triton fix https://github.com/triton-lang/triton/pull/4061 is in pin
            max(
                next_power_of_2(
                    V.graph.sizevars.size_hint(
                        seq_len_q, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
                    )
                    * gqa_shared_heads
                ),
                16,
            )
        ),
    )

    query = ir.ExternKernel.realize_input(query)
    stride_b, stride_hq, stride_seq_len_q, stride_qk_head_dim = query.get_stride()

    # Reshape query for GQA: [B, Hq, Mq, D] -> [B, Hkv, G, Mq, D]
    gqa_query_shape = (B, Hkv, gqa_shared_heads, seq_len_q, qk_head_dim)
    gqa_query_stride = (
        stride_b,
        stride_hq * gqa_shared_heads,
        stride_hq,
        stride_seq_len_q,
        stride_qk_head_dim,
    )
    query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride)

    V.graph.sizevars.guard_leq(
        seq_len_q * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"])
    )

    kernel_options.setdefault(
        "SAFE_M_BOUNDARY",
        ((seq_len_q * gqa_shared_heads) % kernel_options["BLOCK_M"]) == 0,
    )
    # TODO: This feels sketchy
    kernel_options.setdefault("SAFE_N_BOUNDARY", True)

    # Note, we don't need to pass in the captured buffers explicitly
    # because they're implicitly added by the score_mod function
    # We do need to explicitly pass it in for autotuning though.
    for BLOCK_N, num_warps, num_stages in configs:
        if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0:
            continue

        # Performance tuning
        kernel_options.setdefault("BLOCK_N", BLOCK_N)
        kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)

        # Work around https://github.com/pytorch/pytorch/issues/129625
        if num_stages == 2:
            continue
        flex_decoding_template.maybe_append_choice(
            choices=choices,
            input_nodes=[
                query,
                key,
                value,
                buf_M,
                buf_L,
                kv_num_blocks,
                kv_indices,
                full_kv_num_blocks,
                full_kv_indices,
            ],
            layout=layout_acc,
            subgraphs=[
                score_mod_subgraph,
                mask_mod_subgraph,
            ],
            mutated_inputs=[buf_M, buf_L],
            num_stages=num_stages,
            num_warps=num_warps,
            call_sizes=query.get_size(),
            **kernel_options,
        )

    inputs_for_flex_decoding = (
        [
            query,
            key,
            value,
            buf_M,
            buf_L,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
        ]
        + list(score_mod_other_buffers)
        + list(mask_mod_other_buffers)
    )

    input_gen_fns = {
        5: create_num_blocks_fake_generator(kv_indices),
        6: create_indices_fake,
        7: create_num_blocks_fake_generator(full_kv_indices),
        8: create_indices_fake,
    }

    buf_ACC = autotune_select_algorithm(
        "flex_decoding",
        choices,
        inputs_for_flex_decoding,
        layout_acc,
        input_gen_fns=input_gen_fns,
    )

    # Reduction

    g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]
    # See [Note] Handle fully masked out rows:
    # g_M is the global max among split kv blocks.
    masked_rows = lowerings[aten.eq](g_M, -float("inf"))
    adj_M = lowerings[aten.sub](buf_M, g_M)
    adj_M = lowerings[aten.where](masked_rows, 0, adj_M)
    alpha = lowerings[aten.exp2](adj_M)

    buf_L = lowerings[aten.mul](buf_L, alpha)
    g_L = lowerings[aten.sum](buf_L, axis=1)
    masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1)
    g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L)
    logsumexp = lowerings[aten.log2](g_L)
    logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1))

    alpha_unseq = lowerings[aten.unsqueeze](alpha, 4)
    buf_ACC = lowerings[aten.mul](buf_ACC, alpha_unseq)
    output = lowerings[aten.sum](buf_ACC, axis=1)
    L_unseq = lowerings[aten.unsqueeze](g_L, 3)
    output = lowerings[aten.div](output, L_unseq)
    output = lowerings[prims.convert_element_type](output, query.get_dtype())

    return (
        output,
        logsumexp,
    )
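
The reduction at the end of create_flex_decoding_kernel is the standard base-2 logsumexp merge across KV splits, just spelled with aten lowerings. A minimal eager sketch of the same combine (illustrative only; combine_splits is a hypothetical helper, buf_m/buf_l are assumed to be [B, T, Hq, M] and buf_acc [B, T, Hq, M, D] with T = SPLIT_KV, and the fully-masked-row handling is omitted for brevity):

import torch

def combine_splits(buf_m, buf_l, buf_acc):
    g_m = buf_m.amax(dim=1, keepdim=True)      # global rowmax across splits
    alpha = torch.exp2(buf_m - g_m)            # rescale factor per split
    g_l = (buf_l * alpha).sum(dim=1)           # merged sumexp
    out = (buf_acc * alpha.unsqueeze(-1)).sum(dim=1) / g_l.unsqueeze(-1)
    lse = torch.log2(g_l) + g_m.squeeze(1)     # logsumexp, in log2 units
    return out, lse
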
.venv/Lib/site-packages/torch/_inductor/kernel/mm.py
ADDED
@@ -0,0 +1,776 @@
# mypy: allow-untyped-defs
import functools
import logging
from typing import Any, Dict, List, Optional

import torch
from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm
from torch._inductor.autoheuristic.autoheuristic_utils import (
    AHContext,
    context_add_strides,
    context_add_using_tf32,
    get_mixedmm_precondition,
    mixed_mm_operations,
    mm_operations,
)
from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate
from torch._inductor.virtualized import V

from .. import config as inductor_config
from ..codegen.common import BackendFeature
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
from ..codegen.wrapper import WrapperCodeGen
from ..ir import FlexibleLayout, is_triton
from ..lowering import register_lowering
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    NoValidChoicesError,
    TritonTemplate,
)
from ..utils import (
    get_gpu_shared_memory,
    use_aten_gemm_kernels,
    use_ck_template,
    use_cpp_packed_gemm_template,
    use_cutlass_template,
    use_max_autotune,
    use_triton_template,
)
from .mm_common import (
    addmm_epilogue,
    extra_mm_configs,
    int8_mm_configs,
    mixed_mm_configs,
    mm_args,
    mm_configs,
    mm_grid,
    mm_options,
    triton_config,
)


log = logging.getLogger(__name__)
aten = torch.ops.aten

mm_template = TritonTemplate(
    name="mm",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)

aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")

aten_addmm = ExternKernelChoice(
    torch.addmm, "at::addmm_out", op_overload=aten.addmm.default
)

aten__int_mm = ExternKernelChoice(torch._int_mm, "at::_int_mm")

aten__sparse_semi_structured_mm = ExternKernelChoice(
    torch._sparse_semi_structured_mm,
    "at::_sparse_semi_structured_mm",
    has_out_variant=False,
)


def _is_int8_mat(mat):
    return mat.get_dtype() in (torch.int8, torch.uint8)

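
# A reader's sketch of the grouped program-id swizzle used in mm_template
# above (plain-Python mirror of the pid -> (pid_m, pid_n) mapping; _swizzle
# is a hypothetical helper, not used in this file). Grouping GROUP_M rows of
# tiles keeps consecutive programs on nearby B columns, improving L2 reuse.
#
#     def _swizzle(pid: int, grid_m: int, grid_n: int, group_m: int):
#         width = group_m * grid_n
#         group_id = pid // width
#         group_size = min(grid_m - group_id * group_m, group_m)
#         pid_m = group_id * group_m + (pid % group_size)
#         pid_n = (pid % width) // group_size
#         return pid_m, pid_n
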
def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
|
| 145 |
+
"""
|
| 146 |
+
Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
|
| 147 |
+
kernel under the hood. There are a few shapes where this is slower,
|
| 148 |
+
but they are rare.
|
| 149 |
+
"""
|
| 150 |
+
if inp.stride(0) == 0 or inp.size(0) == 1:
|
| 151 |
+
return torch.addmm(inp[0], mat1, mat2, out=out, alpha=alpha, beta=beta)
|
| 152 |
+
return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@register_lowering(aten.mm, type_promotion_kind=None)
def tuned_mm(mat1, mat2, *, layout=None):
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
    name = "mm"

    aten_layout = layout
    if not use_max_autotune():
        aten_layout = FlexibleLayout(
            device=layout.device, dtype=layout.dtype, size=layout.size
        )

    # options to tune from
    choices = (
        [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
    )
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
    if is_nonzero and use_triton_template(layout):
        for config in mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )
    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

    if is_nonzero and use_ck_template(layout, m, n, k):
        CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2])

    if use_cpp_packed_gemm_template(layout, mat1, mat2):
        CppPackedGemmTemplate.add_choices(
            choices,
            layout,
            [mat1, mat2],
        )

    input_nodes = [mat1, mat2]
    if (
        is_nonzero
        and use_triton_template(layout)
        and torch._inductor.config.run_autoheuristic(name)
        and is_triton(mat1)
    ):
        always_included = []
        if use_aten_gemm_kernels():
            always_included.append("extern_mm")
        num_choices_before_extra_configs = len(choices)
        for config in extra_mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )

        # using AutoHeuristic for ranking
        ah_choices = mm_autoheuristic(
            mat1,
            mat2,
            m,
            n,
            k,
            choices,
            name,
            input_nodes,
            mm_operations(),
            None,
            top_k=10,
            always_included=always_included,
        )
        if not torch._inductor.config.collect_autoheuristic(name):
            # if we are collecting data, we do not want to modify choices
            if ah_choices is not None and len(ah_choices) > 0:
                # the order in which autoheuristic returns choices is not the
                # same as the order of choices, which affects things like
                # epilogue fusion. once epilogue fusion benchmarks choices in
                # sorted order, we can just use the order returned by
                # autoheuristic
                choices = [choice for choice in choices if choice in ah_choices]
            else:
                choices = choices[:num_choices_before_extra_configs]

    if (
        len(choices) == 0
        and not use_aten_gemm_kernels()
        and inductor_config.autotune_fallback_to_aten
    ):
        log.warning("No choices for GEMM, using ATen backend as fallback")
        return aten_mm.bind((mat1, mat2), aten_layout).output_node()

    try:
        return autotune_select_algorithm(name, choices, [mat1, mat2], layout)
    except NoValidChoicesError:
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
        return aten_mm.bind((mat1, mat2), aten_layout).output_node()
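
# Illustrative sketch (not part of the original file): tuned_mm is the lowering
# that runs when a matmul is compiled with autotuning enabled; it benchmarks the
# ATen extern kernel against the Triton/CUTLASS/CK choices gathered above and
# keeps the fastest. A CUDA device is assumed; the function name is hypothetical.
def _demo_trigger_tuned_mm():
    import torch

    @torch.compile(mode="max-autotune")
    def f(a, b):
        return a @ b

    a = torch.randn(512, 512, device="cuda", dtype=torch.float16)
    b = torch.randn(512, 512, device="cuda", dtype=torch.float16)
    return f(a, b)  # first call autotunes; later calls reuse the winner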
def _is_static_problem(inputs_tensors, layout):
    # checks whether all input tensors and the output layout have a static
    # shape by attempting to convert the dimensions to int
    static_shape = True
    static_size = WrapperCodeGen.statically_known_list_of_ints_or_none(layout.size)
    if static_size is None:
        nonzero = True
        for s in layout.size:
            sz = WrapperCodeGen.statically_known_int_or_none(s)
            if sz is not None and sz == 0:
                nonzero = False
                break
        return False, nonzero
    numel = 1
    for dim in static_size:
        numel *= dim
    nonzero = numel > 0
    return static_shape, nonzero
@register_lowering(aten._int_mm, type_promotion_kind=None)
def tuned_int_mm(mat1, mat2, *, layout=None):
    m, n, k, layout, mat1, mat2 = mm_args(
        mat1, mat2, layout=layout, out_dtype=torch.int32
    )
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
    use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)

    choices = (
        [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
    )

    # TODO: Re-enable eager mode implementation once cuBLAS is fixed
    if use_cutlass or use_triton_template(layout, enable_int32=True):
        choices = []

    if use_cutlass:
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
        )
    if is_nonzero and use_triton_template(layout, enable_int32=True):
        for config in int8_mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )
    if len(choices) == 0:
        log.warning(
            "No choices for integer GEMM available from the configured backends, "
            "using ATen backend as fallback"
        )
        choices = [aten__int_mm.bind((mat1, mat2), layout)]

    try:
        return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
    except NoValidChoicesError:
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
        choices = [aten__int_mm.bind((mat1, mat2), layout)]
        return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
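
# Illustrative sketch (not part of the original file): the eager-mode op this
# lowering covers is torch._int_mm, which multiplies two int8 matrices with
# int32 accumulation. Exact size constraints depend on the cuBLAS build, so
# the shapes below are conservative multiples of 16; CUDA is assumed.
def _demo_int_mm():
    import torch

    a = torch.randint(-128, 128, (32, 64), dtype=torch.int8, device="cuda")
    b = torch.randint(-128, 128, (64, 32), dtype=torch.int8, device="cuda")
    c = torch._int_mm(a, b)
    assert c.dtype == torch.int32
    # reference on CPU, where integer matmul is supported directly
    torch.testing.assert_close(c.cpu(), a.cpu().int() @ b.cpu().int())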
@register_lowering(aten.addmm, type_promotion_kind=None)
def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    ordered_kwargs_for_cpp_kernel = ("beta", "alpha")
    m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
    static_shape, is_nonzero = _is_static_problem([inp, mat1, mat2], layout)
    if (not is_nonzero) or (not use_max_autotune()):
        # Use a FlexibleLayout if we are not autotuning.
        # This allows padding strides for the output.
        from torch._inductor.ir import FixedLayout, FlexibleLayout

        if isinstance(layout, FixedLayout):
            layout = FlexibleLayout(
                device=layout.device, dtype=layout.dtype, size=layout.size
            )
        choices = (
            [
                aten_addmm.bind(
                    (inp, mat1, mat2),
                    layout,
                    alpha=alpha,
                    beta=beta,
                )
            ]
            if use_aten_gemm_kernels()
            else []
        )
        return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout)

    choices = (
        [
            aten_addmm.bind(
                (inp_expanded, mat1, mat2),
                layout,
                alpha=alpha,
                beta=beta,
            )
        ]
        if use_aten_gemm_kernels()
        else []
    )

    if (
        use_aten_gemm_kernels()
        and inp_expanded.get_stride()[0] == 0
        and inp_expanded.get_device().type == "cuda"
        and inductor_config.triton.autotune_cublasLt
    ):
        # unexpand inp to make sure fused addmm from cublasLt is used
        choices.insert(
            0,
            aten_bias_addmm.bind(
                (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
            ),
        )

    if is_nonzero and use_triton_template(layout):
        for config in mm_configs(m, n, k):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(inp_expanded, mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
                prefix_args=1,
                epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
            )

    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        # Filter out a known cause of CUDA illegal memory access errors:
        # broadcasting on the last dim of the bias term does not seem to work
        # in the linear GEMM epilogue used by addmm.
        if (
            WrapperCodeGen.statically_known_int_or_none(inp_expanded.layout.stride[-1])
            != 0
        ):
            CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
                choices,
                layout,
                [mat1, mat2, inp_expanded],
                alpha=alpha,
                beta=beta,
            )

    if is_nonzero and use_ck_template(layout, m, n, k):
        CKGemmTemplate.add_ck_gemm_choices(
            choices,
            layout,
            [mat1, mat2, inp_expanded],
            alpha=alpha,
            beta=beta,
        )

    if use_cpp_packed_gemm_template(layout, mat1, mat2):
        CppPackedGemmTemplate.add_choices(
            choices,
            layout,
            [inp_expanded, mat1, mat2],
            alpha=alpha,
            beta=beta,
            has_bias=True,
        )

    add_aten_fallback = False
    if len(choices) == 0:
        log.warning("No choices for GEMM, using ATen backend as fallback")
        add_aten_fallback = True

    if add_aten_fallback:
        choices.append(
            aten_addmm.bind(
                (inp_expanded, mat1, mat2),
                layout,
                ordered_kwargs_for_cpp_kernel,
                alpha=alpha,
                beta=beta,
            )
        )

    if (
        inp_expanded.get_stride()[0] == 0
        and inp_expanded.get_device().type == "cuda"
        and inductor_config.triton.autotune_cublasLt
    ):
        # unexpand inp to make sure fused addmm from cublasLt is used
        choices.insert(
            0,
            aten_bias_addmm.bind(
                (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
            ),
        )
    try:
        return autotune_select_algorithm(
            "addmm", choices, [inp_expanded, mat1, mat2], layout
        )
    except NoValidChoicesError:
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
        fallback_choice = aten_addmm.bind(
            (inp, mat1, mat2),
            layout,
            ordered_kwargs_for_cpp_kernel,
            alpha=alpha,
            beta=beta,
        )
        return fallback_choice.output_node()
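
# Illustrative sketch (not part of the original file): in practice this path is
# reached whenever a compiled nn.Linear with a bias runs under max-autotune,
# since linear decomposes to addmm for 2D inputs; CUDA is assumed and the
# function name is hypothetical.
def _demo_trigger_tuned_addmm():
    import torch

    lin = torch.nn.Linear(256, 128, device="cuda", dtype=torch.float16)
    compiled = torch.compile(lin, mode="max-autotune")
    x = torch.randn(64, 256, device="cuda", dtype=torch.float16)
    return compiled(x)  # bias + x @ weight.t() lowers through tuned_addmm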
@register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None)
def tuned_sparse_semi_structured_mm(
    mat1, mat1_meta, mat2, *, out_dtype=None, layout=None
):
    from torch._inductor.select_algorithm import realize_inputs

    mat1, mat1_meta, mat2 = realize_inputs(mat1, mat1_meta, mat2)
    m1, k1 = mat1.get_size()
    m2, _ = mat1_meta.get_size()
    k2, n = mat2.get_size()
    m = V.graph.sizevars.guard_equals(m1, m2)
    k = V.graph.sizevars.guard_equals(2 * k1, k2)

    if layout is None:
        from torch._inductor.ir import FixedLayout

        layout = FixedLayout(
            mat2.get_device(),
            out_dtype if out_dtype else mat2.get_dtype(),
            [m, n],
            [n, 1],
        )
    else:
        assert out_dtype is None, "out_dtype is ignored if layout is specified."

    choices = (
        [
            aten__sparse_semi_structured_mm.bind(
                (mat1, mat1_meta, mat2), layout, out_dtype=out_dtype
            )
        ]
        if use_aten_gemm_kernels()
        else []
    )

    if m * n != 0 and use_cutlass_template(layout, m, n, k):
        CUTLASS2xGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2, mat1_meta], fuseable=True, non_fuseable=True
        )

    return autotune_select_algorithm(
        "sparse_semi_structured_mm", choices, [mat1, mat1_meta, mat2], layout
    )
def fallback_mixed_mm(mat1, mat2, *, out):
    return torch.mm(mat1, mat2.to(mat1.dtype), out=out)


aten_fallback_mixed_mm = ExternKernelChoice(fallback_mixed_mm, None)
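
# Illustrative sketch (not part of the original file): "mixed mm" means the
# second operand is stored in a narrower dtype and upcast on the fly; the
# fallback above simply materializes the upcast before a plain matmul.
def _demo_fallback_mixed_mm():
    import torch

    a = torch.randn(16, 64, dtype=torch.float32)
    b = torch.randint(0, 256, (64, 32), dtype=torch.uint8)
    out = torch.empty(16, 32, dtype=torch.float32)
    return torch.mm(a, b.to(a.dtype), out=out)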
@functools.lru_cache(None)
def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
    props = torch.cuda.get_device_properties(index or 0)
    return props.major <= 7


def dims_are_int(dims):
    return all(isinstance(dim, int) for dim in dims)


def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout):
    m, n, k = get_size_hints(mat1, mat2, m, n, k)
    if not dims_are_int([m, n, k]):
        return None

    if mat1.dtype != torch.float16:
        return None

    # only use this heuristic on A100s:
    # torch.cuda.get_device_capability() >= (8, 0) is also true on an A10G,
    # which does not have enough shared memory for one of the configs
    if (
        not torch.cuda.get_device_capability() >= (8, 0)
    ) or get_gpu_shared_memory() != 166912:
        return None

    if m == 1 and (n % 16 != 0 or k % 16 != 0):
        return None

    if m <= 16 and n >= 4096 and k >= 4096:
        return triton_config(
            BLOCK_M=16,
            BLOCK_N=64,
            BLOCK_K=128,
            num_stages=5,
            num_warps=4,
        )
    elif m > 16 and m <= 32 and n >= 4096 and k >= 4096:
        return triton_config(
            BLOCK_M=32,
            BLOCK_N=32,
            BLOCK_K=128,
            num_stages=5,
            num_warps=4,
        )
    elif m > 32 and m <= 64 and n >= 4096 and k >= 4096:
        return triton_config(
            BLOCK_M=64,
            BLOCK_N=32,
            BLOCK_K=128,
            num_stages=5,
            num_warps=4,
        )
    return None
def mm_autoheuristic(
    mat1,
    mat2,
    m,
    n,
    k,
    choices,
    name,
    input_nodes,
    ops,
    precondition,
    top_k: Optional[int] = None,
    always_included=None,
):
    m, n, k = get_size_hints(mat1, mat2, m, n, k)
    if not dims_are_int([m, n, k]):
        return None
    mat1_stride, mat2_stride = get_size_hints_strides(mat1, mat2)

    def get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride):
        context = AHContext()
        context.add_feature("m", m)
        context.add_feature("k", k)
        context.add_feature("n", n)
        context.add_feature("mat1_dtype", mat1.layout.dtype, is_categorical=True)
        context.add_feature("mat2_dtype", mat2.layout.dtype, is_categorical=True)
        context_add_strides(context, "mat1", mat1_stride)
        context_add_strides(context, "mat2", mat2_stride)
        context.add_feature(
            "mat1_iscontig", mat1.layout.is_contiguous(), is_categorical=True
        )
        context.add_feature(
            "mat2_iscontig", mat2.layout.is_contiguous(), is_categorical=True
        )
        if name == "mm":
            # for mixed_mm, we only consider fp16
            context_add_using_tf32(context, mat1.layout.dtype)
        return context

    def fallback():
        return None

    context = get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride)
    autoheuristic = AutoHeuristicSelectAlgorithm(
        fallback=fallback,
        choices=choices,
        input_nodes=input_nodes,
        context=context,
        name=name,
        augment_context=ops,
        precondition=precondition,
    )

    if top_k is not None:
        # TODO: is there a cleaner way to ensure aten.mm is always included?
        return autoheuristic.get_top_k_choices_caller(
            top_k, always_included=always_included
        )

    return autoheuristic.get_choice_caller()


def get_size_hints(mat1, mat2, m, n, k):
    if not isinstance(m, int) or not isinstance(k, int):
        (m, k) = V.graph.sizevars.size_hints(
            mat1.get_size(),
            fallback=torch._inductor.config.unbacked_symint_fallback,
        )

    if not isinstance(n, int) or not isinstance(k, int):
        (k, n) = V.graph.sizevars.size_hints(
            mat2.get_size(),
            fallback=torch._inductor.config.unbacked_symint_fallback,
        )
    return m, n, k


def get_size_hints_strides(mat1, mat2):
    mat1_stride = mat1.layout.stride
    mat2_stride = mat2.layout.stride
    strides = [mat1_stride, mat2_stride]
    strides_hints = []
    for stride in strides:
        if not isinstance(stride, int):
            stride = V.graph.sizevars.size_hints(
                stride,
                fallback=torch._inductor.config.unbacked_symint_fallback,
            )
        strides_hints.append(stride)
    return strides_hints[0], strides_hints[1]
def tuned_mixed_mm(mat1, mat2, mat2_dtype):
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None)
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)

    fallback = aten_fallback_mixed_mm.bind((mat1, mat2), layout)

    choices = [fallback]

    # conditions under which the triton kernel cannot be used, e.g. on
    # V100-class (sm7x and older) GPUs it has numerical issues
    skip_triton = (
        (
            mat1.layout.dtype != torch.float32
            and not (mat2.layout.is_contiguous() or mat2.layout.is_transposed())
        )
        or _is_sm7x_or_older_gpu(layout.device.index)
        or inductor_config.mixed_mm_choice == "aten"
        or not V.graph.has_feature(layout.device, BackendFeature.TRITON_TEMPLATES)
        or (
            mat1.layout.dtype == torch.float32 and torch.backends.cuda.matmul.allow_tf32
        )
        or (mat1.layout.dtype == torch.bfloat16 and mat2.layout.dtype == torch.uint8)
    )

    if inductor_config.mixed_mm_choice == "triton":
        choices = []

    if not skip_triton:
        b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
        if static_shape and inductor_config.mixed_mm_choice == "heuristic":
            choices = []
            config = try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout)
            if config is not None:
                mm_template.maybe_append_choice(
                    choices,
                    input_nodes=(mat1, mat2),
                    layout=layout,
                    **mm_options(config, m, n, k, layout, b_prologue_cast_type),
                )
            choices.append(fallback)

        has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2)
        for config in mixed_mm_configs(m, n, k, has_int8_tensor=has_int8_tensor):
            mm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout, b_prologue_cast_type),
            )

    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
        )
        CUTLASS2xGemmTemplate.add_cutlass_gemm_choices(
            choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
        )

    if skip_triton and not choices:
        choices = [fallback]

    name = "mixed_mm"
    input_nodes = [mat1, mat2]
    if torch._inductor.config.run_autoheuristic(name):
        choice = mm_autoheuristic(
            mat1,
            mat2,
            m,
            n,
            k,
            choices,
            name,
            input_nodes,
            mixed_mm_operations(),
            get_mixedmm_precondition,
        )
        if (
            not skip_triton
            and inductor_config.mixed_mm_choice == "heuristic"
            and choice is not None
        ):
            choices.insert(0, choice)
    return autotune_select_algorithm(name, choices, input_nodes, layout)
# This op is a special case of the int_mm op which we use based on the pattern
# _int_mm -> mul (defined in ../fx_passes/post_grad.py) in order to prevent
# realization of the int32 _int_mm output by forcing fusion with the mul op.
# This is only used when config.force_fuse_int_mm_with_mul = True
def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
    out_dtype = (
        torch.promote_types(mat3.get_dtype(), torch.int32)
        if out_dtype is None
        else out_dtype
    )
    m, n, k, layout, mat1, mat2, mat3 = mm_args(
        mat1, mat2, mat3, layout=layout, out_dtype=out_dtype
    )
    choices: List[Dict[Any, Any]] = []
    for config in int8_mm_configs(m, n, k):
        mm_template.maybe_append_choice(
            choices,
            input_nodes=(mat1, mat2, mat3),
            layout=layout,
            **dict(mm_options(config, m, n, k, layout), ACC_TYPE="tl.int32"),
            suffix_args=1,
            epilogue_fn=V.ops.mul,
        )
    return autotune_select_algorithm("int_mm", choices, [mat1, mat2, mat3], layout)
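
# Illustrative sketch (not part of the original file): the unfused pattern this
# lowering replaces. Computing _int_mm and the mul separately materializes the
# int32 accumulator; the fused template applies the mul as an epilogue instead.
# CUDA is assumed.
def _demo_int_mm_mul_pattern():
    import torch

    a = torch.randint(-128, 128, (32, 64), dtype=torch.int8, device="cuda")
    b = torch.randint(-128, 128, (64, 32), dtype=torch.int8, device="cuda")
    scale = torch.rand(32, 32, device="cuda")
    acc = torch._int_mm(a, b)  # int32 intermediate the fusion avoids writing out
    return acc * scale  # the trailing mul that becomes the epilogue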
.venv/Lib/site-packages/torch/_inductor/kernel/mm_common.py
ADDED
@@ -0,0 +1,466 @@
# mypy: allow-untyped-defs
import functools
import itertools
import logging
from typing import cast, List, Tuple

import sympy

import torch
from torch._inductor.select_algorithm import realize_inputs
from torch._inductor.virtualized import V

from .. import config as inductor_config
from ..runtime.runtime_utils import next_power_of_2
from ..utils import ceildiv as cdiv


log = logging.getLogger(__name__)


def triton_config(num_stages, num_warps, **kwargs):
    from triton import Config

    return Config(kwargs, num_stages=num_stages, num_warps=num_warps)


def filtered_configs(
    m: int,
    n: int,
    k: int,
    configs: List[Tuple[int, int, int, int, int]],
    has_int8_tensor=False,
):
    """Heuristic to shrink configs when they are bigger than the input size"""

    min_block_size = 16
    # block_k=16 seems to be causing issues
    # see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424
    min_block_size_k = 32 if has_int8_tensor else 16
    m = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                m, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    n = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                n, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    k = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                k, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size_k,
    )
    used = set()
    for block_m, block_n, block_k, num_stages, num_warps in configs:
        # shrink configs for small sizes
        block_m = max(min(block_m, m), min_block_size)
        block_n = max(min(block_n, n), min_block_size)
        block_k = max(min(block_k, k), min_block_size_k)
        # each warp computes a 16x16 tile = 256 elements
        num_warps = min(num_warps, block_m * block_n // 256)
        if torch.version.hip:
            for matrix_instr_nonkdim in [0, 16]:
                if matrix_instr_nonkdim != 0 and (
                    block_m % matrix_instr_nonkdim != 0
                    or block_n % matrix_instr_nonkdim != 0
                ):
                    # block_m and block_n must be a multiple of matrix_instr_nonkdim
                    continue
                if (
                    block_m,
                    block_n,
                    block_k,
                    num_stages,
                    num_warps,
                    matrix_instr_nonkdim,
                ) not in used:
                    used.add(
                        (
                            block_m,
                            block_n,
                            block_k,
                            num_stages,
                            num_warps,
                            matrix_instr_nonkdim,
                        )
                    )
                    yield triton_config(
                        BLOCK_M=block_m,
                        BLOCK_N=block_n,
                        BLOCK_K=block_k,
                        num_stages=num_stages,
                        num_warps=num_warps,
                        matrix_instr_nonkdim=matrix_instr_nonkdim,
                    )
        else:
            if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used:
                used.add((block_m, block_n, block_k, num_stages, num_warps, 0))
                yield triton_config(
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    BLOCK_K=block_k,
                    num_stages=num_stages,
                    num_warps=num_warps,
                )
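
# Illustrative sketch (not part of the original file): how the clamping above
# behaves for a tiny problem. next_power_of_2 here is a local stand-in for the
# runtime helper imported above, and the function names are hypothetical.
def _demo_filtered_configs_clamping():
    def next_power_of_2(x: int) -> int:
        return 1 << max(x - 1, 0).bit_length()

    def shrink(block: int, dim: int, min_block: int = 16) -> int:
        dim = max(next_power_of_2(dim), min_block)
        return max(min(block, dim), min_block)

    # a (128, 128, 64) config applied to a 24 x 24 x 24 matmul
    block_m, block_n, block_k = shrink(128, 24), shrink(128, 24), shrink(64, 24)
    num_warps = min(8, block_m * block_n // 256)  # each warp covers a 16x16 tile
    assert (block_m, block_n, block_k, num_warps) == (32, 32, 32, 4)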
# List of dictionaries to store the kernel configs. Configs whose "cond"
# entry evaluates to true will be utilised on the target platform. Each
# config is (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).
mm_kernel_configs = (
    [
        {"config": (32, 32, 16, 1, 2), "cond": True},
        {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},
        {"config": (32, 64, 32, 5, 8), "cond": True},
        {"config": (64, 32, 32, 5, 8), "cond": True},
        {"config": (64, 32, 128, 5, 4), "cond": True},
        {"config": (64, 64, 16, 2, 4), "cond": True},
        {"config": (64, 64, 32, 2, 4), "cond": True},
        {"config": (64, 64, 64, 3, 8), "cond": True},
        {"config": (64, 64, 128, 5, 4), "cond": True},
        {"config": (64, 128, 32, 3, 4), "cond": True},
        {"config": (64, 128, 32, 4, 8), "cond": True},
        {"config": (64, 128, 64, 3, 4), "cond": True},
        {"config": (64, 128, 128, 4, 4), "cond": True},
        {"config": (128, 64, 32, 3, 4), "cond": True},
        {"config": (128, 64, 32, 4, 8), "cond": True},
        {"config": (128, 128, 32, 2, 8), "cond": True},
        {"config": (128, 128, 32, 3, 4), "cond": True},
        {"config": (128, 128, 64, 3, 4), "cond": True},
        {"config": (128, 128, 64, 5, 8), "cond": True},
    ]
    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
    else [
        {"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True}
        for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
            [16, 32, 64, 128, 256], repeat=3
        )
        for num_stages in [1, 2, 3, 4, 5]
        for num_warps in [2, 4, 8]
    ]
)

# these are only used in tuned_mm when AutoHeuristic is enabled.
# the idea is that when AutoHeuristic collects data to learn a heuristic,
# more configs are autotuned. when the learned heuristic is used, it reduces
# the number of configs down to 10, which saves compilation time (since fewer
# configs are autotuned) and can potentially increase performance, because
# the learned heuristic might predict a config that is not part of mm_configs
extra_mm_kernel_configs = [
    {"config": (16, 32, 16, 3, 2), "cond": True},
    {"config": (16, 32, 32, 4, 2), "cond": True},
    {"config": (16, 32, 32, 5, 2), "cond": True},
    {"config": (64, 64, 128, 3, 4), "cond": True},
    {"config": (128, 64, 32, 2, 2), "cond": True},
    {"config": (128, 64, 64, 3, 8), "cond": True},
    {"config": (128, 64, 128, 4, 8), "cond": True},
    {"config": (128, 128, 32, 4, 4), "cond": True},
    {"config": (128, 128, 64, 3, 8), "cond": True},
    {"config": (128, 128, 64, 5, 4), "cond": True},
]

int8_mm_kernel_configs = [
    {"config": (64, 64, 32, 2, 4), "cond": True},
    {"config": (64, 128, 32, 3, 4), "cond": True},
    {"config": (128, 64, 32, 3, 4), "cond": True},
    {"config": (64, 128, 32, 4, 8), "cond": True},
    {"config": (128, 64, 32, 4, 8), "cond": True},
    {"config": (64, 32, 32, 5, 8), "cond": True},
    {"config": (32, 64, 32, 5, 8), "cond": True},
    {"config": (128, 128, 32, 2, 8), "cond": True},
    {"config": (64, 64, 64, 3, 8), "cond": True},
    # {"config": (32, 32, 128, 2, 4), "cond": True},
    # {"config": (64, 64, 16, 2, 4), "cond": True},
    # {"config": (32, 32, 16, 1, 2), "cond": True},
    {"config": (128, 256, 128, 3, 8), "cond": torch.version.hip is None},
    {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None},
]

# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192).
mixed_mm_kernel_configs_small_m = [
    {"config": (16, 128, 256, 3, 4), "cond": True},
    {"config": (16, 128, 256, 5, 8), "cond": True},
]

mixed_mm_kernel_configs = (
    mm_kernel_configs + mixed_mm_kernel_configs_small_m
    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
    else mm_kernel_configs
)

scaled_mm_kernel_configs = [
    {"config": (128, 256, 32, 3, 8), "cond": True},
    {"config": (256, 128, 32, 3, 8), "cond": True},
    {"config": (256, 64, 32, 4, 4), "cond": True},
    {"config": (64, 256, 32, 4, 4), "cond": True},
    {"config": (128, 128, 32, 4, 4), "cond": True},
    {"config": (128, 64, 32, 4, 4), "cond": True},
    {"config": (64, 128, 32, 4, 4), "cond": True},
    {"config": (128, 32, 32, 4, 4), "cond": True},
    {"config": (64, 32, 32, 5, 2), "cond": True},
    {"config": (256, 128, 128, 3, 8), "cond": True},
    {"config": (256, 64, 128, 4, 4), "cond": True},
    {"config": (64, 256, 128, 4, 4), "cond": True},
    {"config": (128, 128, 128, 4, 4), "cond": True},
    {"config": (128, 64, 64, 4, 4), "cond": True},
    {"config": (64, 128, 64, 4, 4), "cond": True},
    {"config": (128, 32, 64, 4, 4), "cond": True},
    {"config": (64, 32, 64, 5, 2), "cond": True},
    {"config": (16, 32, 32, 2, 2), "cond": True},
    {"config": (16, 64, 32, 2, 2), "cond": True},
    {"config": (16, 128, 32, 2, 4), "cond": True},
    {"config": (16, 256, 32, 2, 4), "cond": True},
    {"config": (16, 32, 64, 2, 2), "cond": True},
    {"config": (16, 64, 64, 2, 2), "cond": True},
    {"config": (16, 128, 64, 2, 4), "cond": True},
    {"config": (16, 256, 64, 2, 4), "cond": True},
    {"config": (32, 32, 32, 2, 2), "cond": True},
    {"config": (32, 64, 32, 2, 2), "cond": True},
    {"config": (32, 128, 32, 2, 4), "cond": True},
    {"config": (32, 256, 32, 2, 4), "cond": True},
    {"config": (32, 32, 64, 2, 2), "cond": True},
    {"config": (32, 64, 64, 2, 2), "cond": True},
    {"config": (32, 128, 64, 2, 4), "cond": True},
    {"config": (32, 256, 64, 2, 4), "cond": True},
    {"config": (16, 32, 32, 3, 2), "cond": True},
    {"config": (16, 64, 32, 3, 2), "cond": True},
    {"config": (16, 128, 32, 3, 4), "cond": True},
    {"config": (16, 256, 32, 3, 4), "cond": True},
    {"config": (16, 32, 64, 3, 2), "cond": True},
    {"config": (16, 64, 64, 3, 2), "cond": True},
    {"config": (16, 128, 64, 3, 4), "cond": True},
    {"config": (16, 256, 64, 3, 4), "cond": True},
    {"config": (32, 32, 32, 3, 2), "cond": True},
    {"config": (32, 64, 32, 3, 2), "cond": True},
    {"config": (32, 128, 32, 3, 4), "cond": True},
    {"config": (32, 256, 32, 3, 4), "cond": True},
    {"config": (32, 32, 64, 3, 2), "cond": True},
    {"config": (32, 64, 64, 3, 2), "cond": True},
    {"config": (32, 128, 64, 3, 4), "cond": True},
    {"config": (32, 256, 64, 3, 4), "cond": True},
    {"config": (16, 32, 32, 4, 2), "cond": True},
    {"config": (16, 64, 32, 4, 2), "cond": True},
    {"config": (16, 128, 32, 4, 4), "cond": True},
    {"config": (16, 256, 32, 4, 4), "cond": True},
    {"config": (16, 32, 64, 4, 2), "cond": True},
    {"config": (16, 64, 64, 4, 2), "cond": True},
    {"config": (16, 128, 64, 4, 4), "cond": True},
    {"config": (16, 256, 64, 4, 4), "cond": True},
    {"config": (32, 32, 32, 4, 2), "cond": True},
    {"config": (32, 64, 32, 4, 2), "cond": True},
    {"config": (32, 128, 32, 4, 4), "cond": True},
    {"config": (32, 256, 32, 4, 4), "cond": True},
    {"config": (32, 32, 64, 4, 2), "cond": True},
    {"config": (32, 64, 64, 4, 2), "cond": True},
    {"config": (32, 128, 64, 4, 4), "cond": True},
    {"config": (32, 256, 64, 4, 4), "cond": True},
    {"config": (16, 32, 32, 5, 2), "cond": True},
    {"config": (16, 64, 32, 5, 2), "cond": True},
    {"config": (16, 128, 32, 5, 4), "cond": True},
    {"config": (16, 256, 32, 5, 4), "cond": True},
    {"config": (16, 32, 64, 5, 2), "cond": True},
    {"config": (16, 64, 64, 5, 2), "cond": True},
    {"config": (16, 128, 64, 5, 4), "cond": True},
    {"config": (16, 256, 64, 5, 4), "cond": True},
    {"config": (32, 32, 32, 5, 2), "cond": True},
    {"config": (32, 64, 32, 5, 2), "cond": True},
    {"config": (32, 128, 32, 5, 4), "cond": True},
    {"config": (32, 256, 32, 5, 4), "cond": True},
    {"config": (32, 32, 64, 5, 2), "cond": True},
    {"config": (32, 64, 64, 5, 2), "cond": True},
    {"config": (32, 128, 64, 5, 4), "cond": True},
    {"config": (32, 256, 64, 5, 4), "cond": True},
    {"config": (16, 32, 32, 6, 2), "cond": True},
    {"config": (16, 64, 32, 6, 2), "cond": True},
    {"config": (16, 128, 32, 6, 4), "cond": True},
    {"config": (16, 256, 32, 6, 4), "cond": True},
    {"config": (16, 32, 64, 6, 2), "cond": True},
    {"config": (16, 64, 64, 6, 2), "cond": True},
    {"config": (16, 128, 64, 6, 4), "cond": True},
    {"config": (16, 256, 64, 6, 4), "cond": True},
    {"config": (32, 32, 32, 6, 2), "cond": True},
    {"config": (32, 64, 32, 6, 2), "cond": True},
    {"config": (32, 128, 32, 6, 4), "cond": True},
    {"config": (32, 256, 32, 6, 4), "cond": True},
    {"config": (32, 32, 64, 6, 2), "cond": True},
    {"config": (32, 64, 64, 6, 2), "cond": True},
    {"config": (32, 128, 64, 6, 4), "cond": True},
    {"config": (32, 256, 64, 6, 4), "cond": True},
]


# Create filtered list of configs based on cond evaluation
mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in mm_kernel_configs
    if config["cond"]
)
extra_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in extra_mm_kernel_configs
    if config["cond"]
)
int8_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in int8_mm_kernel_configs
    if config["cond"]
)
mixed_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in mixed_mm_kernel_configs
    if config["cond"]
)
scaled_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in scaled_mm_kernel_configs
    if config["cond"]
)
# On ROCm convert num_stages to 0 to enable software pipelining
if torch.version.hip:
    mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in mm_platform_configs
    )
    extra_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in extra_mm_platform_configs
    )
    int8_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in int8_platform_configs
    )
    mixed_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in mixed_mm_platform_configs
    )
    scaled_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in scaled_mm_platform_configs
    )

mm_configs = functools.partial(
    filtered_configs,
    configs=mm_platform_configs,
)

extra_mm_configs = functools.partial(
    filtered_configs,
    configs=extra_mm_platform_configs,
)

int8_mm_configs = functools.partial(
    filtered_configs,
    configs=int8_platform_configs,
)

mixed_mm_configs = functools.partial(
    filtered_configs,
    configs=mixed_mm_platform_configs,
)

scaled_mm_configs = functools.partial(
    filtered_configs,
    configs=scaled_mm_platform_configs,
)
def mm_grid(m, n, meta):
    """
    The CUDA grid size for matmul triton templates.
    """
    return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)


def acc_type(dtype):
    if dtype in (torch.float16, torch.bfloat16):
        return "tl.float32"
    return f"tl.{dtype}".replace("torch.", "")


def mm_options(config, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None):
    """
    Common options to matmul triton templates.
    """
    even_k_symbolic = (
        # it isn't worth guarding on this
        sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
        == config.kwargs["BLOCK_K"]
    )
    allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
        not inductor_config.force_same_precision
        or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0)
    )
    return dict(
        GROUP_M=8,
        EVEN_K=even_k_symbolic,
        ALLOW_TF32=allow_tf32,
        ACC_TYPE=acc_type(layout.dtype),
        B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
        num_stages=config.num_stages,
        num_warps=config.num_warps,
        **config.kwargs,
    )
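
# Illustrative sketch (not part of the original file): mm_grid flattens the 2D
# tile grid into a single launch dimension.
def _demo_mm_grid():
    from math import ceil

    M, N = 512, 512
    meta = {"BLOCK_M": 64, "BLOCK_N": 128}
    grid = (ceil(M / meta["BLOCK_M"]) * ceil(N / meta["BLOCK_N"]), 1, 1)
    assert grid == (32, 1, 1)  # 8 row tiles x 4 column tiles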
def mm_args(
    mat1,
    mat2,
    *others,
    layout=None,
    out_dtype=None,
    use_4x2_dim=False,
    mat2_transposed=False,
):
    """
    Common arg processing for mm, bmm, addmm, etc.
    """
    mat1, mat2 = realize_inputs(mat1, mat2)
    *b1, m, k1 = mat1.get_size()
    if mat2_transposed:
        *b2, n, k2 = mat2.get_size()
    else:
        *b2, k2, n = mat2.get_size()
    b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
    if use_4x2_dim:
        k2 = k2 * 2
    k = V.graph.sizevars.guard_equals(k1, k2)
    if layout is None:
        from torch._inductor.ir import FixedLayout

        if out_dtype is None:
            out_dtype = mat1.get_dtype()

        layout = FixedLayout(
            mat1.get_device(),
            out_dtype,
            [*b, m, n],
        )
    else:
        assert out_dtype is None, "out_dtype is ignored if layout is specified."
    from ..lowering import expand

    others = [realize_inputs(expand(x, layout.size)) for x in others]

    return [m, n, k, layout, mat1, mat2, *others]


def addmm_epilogue(dtype, alpha, beta):
    def epilogue(acc, bias):
        if alpha != 1:
            acc = V.ops.mul(acc, V.ops.constant(alpha, dtype))
        if beta != 1:
            bias = V.ops.mul(bias, V.ops.constant(beta, dtype))
        return V.ops.add(acc, bias)

    return epilogue
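
# Illustrative sketch (not part of the original file): the epilogue built by
# addmm_epilogue has to reproduce torch.addmm's semantics,
# out = beta * bias + alpha * (mat1 @ mat2).
def _demo_addmm_epilogue_identity():
    import torch

    m, k, n = 4, 8, 5
    mat1, mat2 = torch.randn(m, k), torch.randn(k, n)
    bias = torch.randn(m, n)
    alpha, beta = 2.0, 0.5
    torch.testing.assert_close(
        torch.addmm(bias, mat1, mat2, alpha=alpha, beta=beta),
        alpha * (mat1 @ mat2) + beta * bias,
    )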
.venv/Lib/site-packages/torch/_inductor/kernel/mm_plus_mm.py
ADDED
@@ -0,0 +1,248 @@
# mypy: allow-untyped-defs
import functools

import torch

from ..lowering import lowerings
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    TritonTemplate,
)
from ..utils import use_aten_gemm_kernels, use_triton_template
from ..virtualized import V
from .mm_common import mm_args, mm_grid, mm_options


aten = torch.ops.aten

aten_mm_plus_mm = ExternKernelChoice(
    torch.ops.inductor._mm_plus_mm, "torch::inductor::_mm_plus_mm"
)

mm_plus_mm_template = TritonTemplate(
    name="mm_plus_mm",
    grid=mm_grid,
    debug=False,
    source=r"""
{{def_kernel("A", "B", "C", "D")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K1 = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    # K2 = {{size("C", 1)}}
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}
    stride_cm = {{stride("C", 0)}}
    stride_ck = {{stride("C", 1)}}
    stride_dk = {{stride("D", 0)}}
    stride_dn = {{stride("D", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    if (((stride_am == 1 and stride_ak == M) or (stride_am == K1 and stride_ak == 1))
        and ((stride_cm == 1 and stride_ck == M) or (stride_cm == K1 and stride_ck == 1))):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M

    if (((stride_bk == 1 and stride_bn == K1) or (stride_bk == N and stride_bn == 1))
        and ((stride_dk == 1 and stride_dn == K1) or (stride_dk == N and stride_dn == 1))):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N

    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck)
    D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k1 in range(K1, 0, -BLOCK_K):
        # First matmul with A @ B
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k1, other=0.)
            b = tl.load(B, mask=rk[:, None] < k1, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    for k2 in range(K1, 0, -BLOCK_K):
        # Second matmul with C @ D
        if EVEN_K:
            c = tl.load(C)
            d = tl.load(D)
        else:
            c = tl.load(C, mask=rk[None, :] < k2, other=0.)
            d = tl.load(D, mask=rk[:, None] < k2, other=0.)
        acc += tl.dot(c, d, allow_tf32=ALLOW_TF32)
        C += BLOCK_K * stride_ck
        D += BLOCK_K * stride_dk

    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
@functools.lru_cache(None)
def mm_configs():
    import triton

    # List of dictionaries to store the kernel configs. Configs whose "cond"
    # entry evaluates to true will be utilised on the target platform.
    mm_triton_configs = [
        {
            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
            "num_stages": 2,
            "num_warps": 4,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
            "num_stages": 3,
            "num_warps": 8,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
            "num_stages": 4,
            "num_warps": 16,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32},
            "num_stages": 4,
            "num_warps": 8,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32},
            "num_stages": 4,
            "num_warps": 8,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32},
            "num_stages": 1,
            "num_warps": 8,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64},
            "num_stages": 1,
            "num_warps": 8,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128},
            "num_stages": 1,
            "num_warps": 8,
            "cond": torch.version.hip is None,
        },
        {
            "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16},
            "num_stages": 2,
            "num_warps": 4,
            "cond": True,
        },
        {
            "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16},
            "num_stages": 1,
            "num_warps": 2,
            "cond": True,
        },
    ]

    # Keep only the configs whose cond evaluates to true.
    # On ROCm convert num_stages to 1, as pipelining provides no benefit.
    if torch.version.hip:
        filtered_configs = [
            triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"])
            for c in mm_triton_configs
            if c["cond"]
        ]
    else:
        filtered_configs = [
            triton.Config(
                c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"]
            )
            for c in mm_triton_configs
            if c["cond"]
        ]

    return filtered_configs


def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
    """
    Computes mm(mat1, mat2) + mm(mat3, mat4)
    """
    m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
    m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
    # Optimization is optional, because we can always just not do the fusion
    if (
        m1 * n1 == 0
        or m2 * n2 == 0
        or not V.graph.sizevars.statically_known_list_equals(
            mat1.get_size(), mat3.get_size()
        )
        or not V.graph.sizevars.statically_known_list_equals(
            mat2.get_size(), mat4.get_size()
        )
    ):
        # TODO(jansel): support different K values when this is fixed:
        # https://github.com/openai/triton/issues/967
        return lowerings[aten.add](
            lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4)
        )

    assert layout1 == layout2
    # options to tune from
    choices = (
        [aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)]
        if use_aten_gemm_kernels()
        else []
    )
    if use_triton_template(layout1):
        for config in mm_configs():
            # see https://github.com/openai/triton/issues/1298
            # BLOCK_K = K causes llvm error
            if V.graph.sizevars.statically_known_lt(config.kwargs["BLOCK_K"], k1):
                mm_plus_mm_template.maybe_append_choice(
                    choices,
                    input_nodes=(mat1, mat2, mat3, mat4),
                    layout=layout1,
                    **mm_options(config, m1, n1, k1, layout1),
                )

    return autotune_select_algorithm(
        "mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1
    )
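
# Illustrative sketch (not part of the original file): the fused kernel walks
# both K loops into a single accumulator; the unfused decomposition it falls
# back to (mismatched shapes, zero-size inputs) is just two matmuls and an add.
def _demo_mm_plus_mm():
    import torch

    a, b = torch.randn(64, 32), torch.randn(32, 48)
    c, d = torch.randn(64, 32), torch.randn(32, 48)
    return torch.mm(a, b) + torch.mm(c, d)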
.venv/Lib/site-packages/torch/_inductor/kernel/mm_scaled.py
ADDED
|
@@ -0,0 +1,311 @@
import logging
from typing import Any, Dict, List, Optional, Tuple

import sympy

import torch

from .. import config as inductor_config
from ..ir import ChoiceCaller, Layout, StorageBox, TensorBox
from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    NoValidChoicesError,
    realize_inputs,
    TritonTemplate,
)
from ..utils import use_aten_gemm_kernels, use_triton_template
from .mm import _is_static_problem  # TODO(yangsiyu) move to mm_common
from .mm_common import mm_args, mm_grid, scaled_mm_configs


log = logging.getLogger(__name__)
aten = torch.ops.aten


scaled_mm_template = TritonTemplate(
    name="scaled_mm",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        if USE_FAST_ACCUM:
            acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE)
        else:
            acc += tl.dot(a, b, out_dtype=ACC_TYPE)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    if SCALING_ROWWISE:
        inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M)
        inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N)
        inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :]
        acc *= inv_scale_row
    else:
        # for tensor-wise scaling, the scales are scalars
        inv_a_scale = tl.load(A_inverse_scale)
        inv_b_scale = tl.load(B_inverse_scale)
        inv_scale = inv_a_scale * inv_b_scale
        acc *= inv_scale

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)


# Inductor does not allow optional tensor input arguments currently (pass None as an
# input node to template choices), but since for _scaled_mm there is only one such arg
# (bias), work around by having a second template when bias is provided.
scaled_mm_bias_template = TritonTemplate(
    name="scaled_mm_bias",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale", "bias_ptr")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        if USE_FAST_ACCUM:
            acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE)
        else:
            acc += tl.dot(a, b, out_dtype=ACC_TYPE)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    if SCALING_ROWWISE:
        inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M)
        inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N)
        inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :]
        acc *= inv_scale_row
    else:
        # for tensor-wise scaling, the scales are scalars
        inv_a_scale = tl.load(A_inverse_scale)
        inv_b_scale = tl.load(B_inverse_scale)
        inv_scale = inv_a_scale * inv_b_scale
        acc *= inv_scale

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # bias
    bias = tl.load(bias_ptr + rn, mask=rn < N)
    acc += bias

    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)


aten__fp8_mm = ExternKernelChoice(torch._scaled_mm, "at::_scaled_mm")


def are_compatible_scales(size_a: List[int], size_b: List[int]) -> bool:
    # Same sized scales are compatible
    if len(size_a) == len(size_b):
        return True

    # Both need to be scalars or len(1) tensors
    if len(size_a) <= 1 and len(size_b) <= 1:
        return True

    return False


def scaled_mm_options(  # type: ignore[no-untyped-def]
    config,  # triton.Config
    sym_m: sympy.core.numbers.Integer,
    sym_n: sympy.core.numbers.Integer,
    sym_k: sympy.core.numbers.Integer,
    layout: Layout,
    scale_a: StorageBox,
    scale_b: StorageBox,
    use_fast_accum: bool,
    b_prologue_cast_type: Optional[str] = None,
) -> Dict[str, Any]:
    even_k_symbolic = (
        sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
    )

    size_a, size_b = scale_a.get_size(), scale_b.get_size()
    assert are_compatible_scales(size_a, size_b), (
        "Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
        f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
    )
    return dict(
        GROUP_M=8,
        EVEN_K=even_k_symbolic,
        ACC_TYPE="tl.float32",
        B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
        USE_FAST_ACCUM=use_fast_accum,
        num_stages=config.num_stages,
        num_warps=config.num_warps,
        # tensor-wise scaling when scales are scalars; row-wise when they are 2-D
        SCALING_ROWWISE=len(scale_a.get_size()) == 2,
        **config.kwargs,
    )


add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides)


@register_lowering(aten._scaled_mm.default, type_promotion_kind=None)  # type: ignore[misc]
def tuned_scaled_mm(
    mat_a: TensorBox,
    mat_b: TensorBox,
    scale_a: TensorBox,
    scale_b: TensorBox,
    bias: Optional[TensorBox] = None,
    scale_result: Optional[TensorBox] = None,
    out_dtype: Optional[torch.dtype] = None,
    use_fast_accum: bool = False,
    layout: Optional[Layout] = None,
) -> TensorBox:
    m, n, k, layout, mat_a, mat_b = mm_args(
        mat_a, mat_b, layout=layout, out_dtype=out_dtype
    )
    scale_a, scale_b = realize_inputs(scale_a, scale_b)

    input_nodes: Tuple[Any, ...]
    # workaround for Inductor not supporting optional tensor input arguments
    if bias is None:
        input_nodes = (mat_a, mat_b, scale_a, scale_b)
        triton_template = scaled_mm_template
    else:
        bias = realize_inputs(bias)
        input_nodes = (mat_a, mat_b, scale_a, scale_b, bias)
        triton_template = scaled_mm_bias_template

    aten_choice = aten__fp8_mm.bind(
        input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum
    )

    choices: List[ChoiceCaller] = []
    if use_aten_gemm_kernels():
        choices.append(aten_choice)

    static_shape, is_nonzero = _is_static_problem([mat_a, mat_b], layout)
    if is_nonzero and use_triton_template(layout, enable_float8=True):
        for config in scaled_mm_configs(m, n, k):
            if k == 16 and config.kwargs["BLOCK_M"] >= 64:
                continue  # Triton crashes in this case
            kwargs = scaled_mm_options(
                config, m, n, k, layout, scale_a, scale_b, use_fast_accum
            )
            # possibly appends a TritonTemplateCaller to choices
            triton_template.maybe_append_choice(
                choices,
                input_nodes=input_nodes,
                layout=layout,
                **kwargs,
            )

    if (
        len(choices) == 0
        and not use_aten_gemm_kernels()
        and inductor_config.autotune_fallback_to_aten
    ):
        log.warning("No choices for scaled_mm, using ATen backend as fallback")
        return aten_choice.output_node()

    try:
        return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)
    except NoValidChoicesError:
        if not inductor_config.autotune_fallback_to_aten:
            raise
        log.warning(
            "All choices for scaled_mm were invalid, using ATen backend as fallback"
        )
        return aten_choice.output_node()
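The lowering above targets the private `torch._scaled_mm` op. As a rough usage sketch of that op, assuming an FP8-capable GPU (e.g. H100) and the argument order shown in `tuned_scaled_mm`; `_scaled_mm` is a private API whose signature has changed across PyTorch releases, so treat this as illustrative:

```python
import torch

# The ATen kernel expects the second operand column-major, hence the .t().
a = torch.randn(128, 64, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(32, 64, device="cuda").to(torch.float8_e4m3fn).t()

# Tensor-wise (scalar) inverse scales; row-wise scaling would use 2-D
# scales per the SCALING_ROWWISE logic above.
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
```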
.venv/Lib/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py
ADDED
@@ -0,0 +1,87 @@
# mypy: allow-untyped-defs
import logging
from typing import List, TYPE_CHECKING

from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from .mm_common import mm_args, mm_configs, mm_grid, mm_options


if TYPE_CHECKING:
    from ..ir import ChoiceCaller

log = logging.getLogger(__name__)

uint4x2_mixed_mm_template = TritonTemplate(
    name="uint4x2_mixed_mm",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None]//2 * stride_bk + rbn[None, :] * stride_bn)
    b_shifts = 4*(rk%2)
    b_subs = 8*(1-(rk%2))

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        b = ((b >> b_shifts[:, None]) & 0xF) - 8
        b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K//2 * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)


def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype):
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True)
    choices: List[ChoiceCaller] = []
    b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
    for config in mm_configs(m, n, k):
        uint4x2_mixed_mm_template.maybe_append_choice(
            choices,
            input_nodes=(mat1, mat2),
            layout=layout,
            **mm_options(config, m, n, k, layout, b_prologue_cast_type),
        )
    return autotune_select_algorithm("uint4x2_mixed_mm", choices, [mat1, mat2], layout)
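The template's prologue `b = ((b >> b_shifts[:, None]) & 0xF) - 8` assumes two signed int4 values packed into each uint8 byte, stored with a +8 bias so they fit in [0, 15]; even k indices live in the low nibble (shift 0) and odd ones in the high nibble (shift 4). A small host-side sketch of that packing and unpacking (variable names and the packing order are illustrative assumptions, not part of the Inductor API):

```python
import torch

# Signed int4 values in [-8, 7], biased by +8 and packed two per byte.
vals = torch.randint(-8, 8, (8,), dtype=torch.int16)
packed = ((vals[0::2] + 8) | ((vals[1::2] + 8) << 4)).to(torch.uint8)

# Mirror of the kernel's unpack: ((b >> shift) & 0xF) - 8,
# where the shift alternates 0 and 4 with rk % 2.
lo = (packed.to(torch.int16) & 0xF) - 8          # even k indices (shift 0)
hi = ((packed.to(torch.int16) >> 4) & 0xF) - 8   # odd k indices (shift 4)
assert torch.equal(lo, vals[0::2]) and torch.equal(hi, vals[1::2])
```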