ayousanz committed
Commit fb9e29e (verified)
1 Parent(s): cd32371

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .venv/Lib/site-packages/torch/_inductor/__pycache__/codecache.cpython-39.pyc +0 -0
  2. .venv/Lib/site-packages/torch/_inductor/__pycache__/config.cpython-39.pyc +0 -0
  3. .venv/Lib/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-39.pyc +0 -0
  4. .venv/Lib/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-39.pyc +0 -0
  5. .venv/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-39.pyc +0 -0
  6. .venv/Lib/site-packages/torch/_inductor/__pycache__/exc.cpython-39.pyc +0 -0
  7. .venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py +296 -0
  8. .venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py +321 -0
  9. .venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py +149 -0
  10. .venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py +109 -0
  11. .venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py +0 -0
  12. .venv/Lib/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py +32 -0
  13. .venv/Lib/site-packages/torch/_inductor/codegen/codegen_device_driver.py +91 -0
  14. .venv/Lib/site-packages/torch/_inductor/codegen/common.py +2167 -0
  15. .venv/Lib/site-packages/torch/_inductor/codegen/cpp.py +0 -0
  16. .venv/Lib/site-packages/torch/_inductor/codegen/cpp_gemm_template.py +1043 -0
  17. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py +0 -0
  18. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py +173 -0
  19. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py +204 -0
  20. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py +203 -0
  21. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py +219 -0
  22. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py +129 -0
  23. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py +209 -0
  24. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py +227 -0
  25. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py +598 -0
  26. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py +243 -0
  27. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py +452 -0
  28. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py +208 -0
  29. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py +173 -0
  30. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py +189 -0
  31. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py +189 -0
  32. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py +177 -0
  33. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py +193 -0
  34. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py +220 -0
  35. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py +204 -0
  36. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py +220 -0
  37. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py +52 -0
  38. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py +44 -0
  39. .venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py +44 -0
  40. .venv/Lib/site-packages/torch/_inductor/fx_passes/split_cat.py +0 -0
  41. .venv/Lib/site-packages/torch/_inductor/kernel/__init__.py +1 -0
  42. .venv/Lib/site-packages/torch/_inductor/kernel/bmm.py +192 -0
  43. .venv/Lib/site-packages/torch/_inductor/kernel/conv.py +679 -0
  44. .venv/Lib/site-packages/torch/_inductor/kernel/flex_attention.py +1843 -0
  45. .venv/Lib/site-packages/torch/_inductor/kernel/flex_decoding.py +570 -0
  46. .venv/Lib/site-packages/torch/_inductor/kernel/mm.py +776 -0
  47. .venv/Lib/site-packages/torch/_inductor/kernel/mm_common.py +466 -0
  48. .venv/Lib/site-packages/torch/_inductor/kernel/mm_plus_mm.py +248 -0
  49. .venv/Lib/site-packages/torch/_inductor/kernel/mm_scaled.py +311 -0
  50. .venv/Lib/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py +87 -0
.venv/Lib/site-packages/torch/_inductor/__pycache__/codecache.cpython-39.pyc ADDED
Binary file (92.1 kB).
 
.venv/Lib/site-packages/torch/_inductor/__pycache__/config.cpython-39.pyc ADDED
Binary file (17.3 kB).
 
.venv/Lib/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-39.pyc ADDED
Binary file (35.9 kB).
 
.venv/Lib/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-39.pyc ADDED
Binary file (9.92 kB).
 
.venv/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-39.pyc ADDED
Binary file (10.5 kB).
 
.venv/Lib/site-packages/torch/_inductor/__pycache__/exc.cpython-39.pyc ADDED
Binary file (4.59 kB).
 
.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py ADDED
@@ -0,0 +1,296 @@
1
+ # flake8: noqa: B950
2
+ # fmt: off
3
+ # This file was generated by AutoHeuristic. Do not modify it manually!
4
+ # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/
5
+ from typing import List, Optional, Tuple
6
+
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ Choice,
11
+ )
12
+ from torch._inductor.autoheuristic.learnedheuristic_interface import (
13
+ LearnedHeuristicDecision,
14
+ )
15
+
16
+
17
+ class MMRankingA100(LearnedHeuristicDecision):
18
+
19
+ def __init__(self) -> None:
20
+ self.choices: List[Choice] = []
21
+ self.fill_choices()
22
+
23
+ def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
24
+ return (
25
+ metadata.name == self.get_name()
26
+ and metadata.shared_memory == 166912
27
+ and str(metadata.device_capa) == "(8, 0)"
28
+ )
29
+
30
+ def get_confidence_threshold(self) -> float:
31
+ return 0.0
32
+
33
+ def get_choice(self, idx: int) -> Optional[str]:
34
+ if idx < len(self.choices):
35
+ return self.choices[idx]
36
+ return None
37
+
38
+ def fill_choices(self) -> None:
39
+ self.choices.append('extern_mm')
40
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8')
41
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8')
42
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
43
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
44
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
45
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
46
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
47
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
48
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
49
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
50
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8')
51
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
52
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8')
53
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
54
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8')
55
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
56
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8')
57
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
58
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8')
59
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
60
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8')
61
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
62
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
63
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
64
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
65
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
66
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8')
67
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
68
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
69
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
70
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
71
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
72
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
73
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
74
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
75
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8')
76
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
77
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4')
78
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8')
79
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8')
80
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
81
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8')
82
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
83
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8')
84
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
85
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8')
86
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2')
87
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8')
88
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
89
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
90
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
91
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
92
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
93
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
94
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2')
95
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8')
96
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
97
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
98
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
99
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
100
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
101
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
102
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
103
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
104
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
105
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
106
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
107
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8')
108
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8')
109
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
110
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8')
111
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
112
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
113
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8')
114
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
115
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8')
116
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
117
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
118
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8')
119
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
120
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8')
121
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
122
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
123
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
124
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
125
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
126
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
127
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
128
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
129
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
130
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
131
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1')
132
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
133
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
134
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2')
135
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2')
136
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2')
137
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
138
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
139
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
140
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
141
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
142
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1')
143
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2')
144
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
145
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
146
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
147
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
148
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
149
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=2')
150
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
151
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
152
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
153
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
154
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2')
155
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
156
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
157
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
158
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
159
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
160
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2')
161
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
162
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2')
163
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
164
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
165
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
166
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
167
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
168
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2')
169
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
170
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
171
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
172
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4')
173
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
174
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4')
175
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4')
176
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4')
177
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4')
178
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4')
179
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
180
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
181
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
182
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
183
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
184
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
185
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
186
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
187
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4')
188
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
189
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
190
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
191
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
192
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
193
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=8')
194
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
195
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
196
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
197
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
198
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
199
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
200
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
201
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
202
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
203
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
204
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
205
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
206
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
207
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4')
208
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
209
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
210
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
211
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
212
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
213
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
214
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
215
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
216
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
217
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
218
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
219
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
220
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
221
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
222
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
223
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
224
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
225
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4')
226
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
227
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4')
228
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
229
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
230
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
231
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4')
232
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
233
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
234
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
235
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
236
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
237
+
238
+ def get_name(self) -> str:
239
+ return 'mm'
240
+
241
+ def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
242
+ if context.get_value('arith_intensity') <= 52.6245059967041:
243
+ if context.get_value('n') <= 34.0:
244
+ if context.get_value('n') <= 18.0:
245
+ if context.get_value('k*n') <= 312.0:
246
+ return [(0.093, 12), (0.081, 16), (0.081, 148), (0.070, 10), (0.070, 17), (0.070, 149), (0.070, 151), (0.070, 150), (0.070, 14), (0.058, 11), (0.058, 15), (0.058, 13), (0.058, 122), (0.047, 121), (0.035, 123), (0.012, 92)]
247
+ else:
248
+ if context.get_value('k') <= 40.0:
249
+ return [(0.083, 42), (0.083, 46), (0.083, 44), (0.083, 40), (0.083, 128), (0.067, 45), (0.067, 43), (0.067, 41), (0.067, 169), (0.067, 171), (0.067, 168), (0.067, 129), (0.067, 170), (0.033, 103), (0.017, 121)]
250
+ else:
251
+ return [(0.112, 137), (0.104, 136), (0.101, 0), (0.081, 1), (0.073, 135), (0.069, 67), (0.066, 187), (0.058, 41), (0.050, 71), (0.046, 68), (0.046, 70), (0.031, 44), (0.027, 43), (0.027, 170), (0.019, 189), (0.019, 188), (0.015, 169), (0.015, 171), (0.012, 115), (0.012, 168), (0.012, 69), (0.004, 103)]
252
+ else:
253
+ if context.get_value('mat1_stride_0') <= 20.0:
254
+ return [(0.069, 0), (0.059, 157), (0.059, 22), (0.059, 153), (0.059, 155), (0.059, 25), (0.059, 23), (0.059, 19), (0.044, 21), (0.044, 18), (0.044, 152), (0.044, 158), (0.044, 154), (0.044, 156), (0.044, 20), (0.044, 124), (0.044, 24), (0.030, 125), (0.029, 126), (0.015, 97), (0.015, 95), (0.015, 96), (0.010, 2), (0.010, 75)]
255
+ else:
256
+ if context.get_value('k') <= 68.0:
257
+ return [(0.087, 72), (0.087, 74), (0.087, 73), (0.086, 76), (0.077, 75), (0.067, 192), (0.058, 190), (0.048, 47), (0.048, 193), (0.048, 49), (0.048, 51), (0.048, 191), (0.038, 53), (0.019, 133), (0.019, 50), (0.019, 175), (0.019, 172), (0.019, 48), (0.019, 174), (0.010, 173), (0.010, 177), (0.010, 52), (0.010, 54), (0.010, 178), (0.010, 176)]
258
+ else:
259
+ return [(0.154, 52), (0.154, 72), (0.102, 75), (0.087, 49), (0.087, 73), (0.086, 51), (0.057, 176), (0.045, 2), (0.038, 191), (0.038, 178), (0.038, 190), (0.029, 173), (0.029, 76), (0.026, 138), (0.013, 139), (0.013, 140), (0.003, 0)]
260
+ else:
261
+ if context.get_value('k') <= 35.0:
262
+ if context.get_value('k') <= 18.0:
263
+ if context.get_value('m*n') <= 19505152.0:
264
+ return [(0.151, 159), (0.140, 160), (0.129, 164), (0.055, 127), (0.051, 29), (0.044, 161), (0.044, 147), (0.040, 146), (0.040, 31), (0.037, 145), (0.026, 28), (0.022, 90), (0.022, 93), (0.022, 94), (0.022, 100), (0.022, 125), (0.022, 158), (0.022, 157), (0.011, 87), (0.011, 88), (0.011, 89), (0.011, 91), (0.011, 95), (0.011, 96), (0.011, 98), (0.011, 99)]
265
+ else:
266
+ return [(0.069, 7), (0.069, 5), (0.067, 147), (0.066, 8), (0.061, 145), (0.058, 146), (0.052, 124), (0.049, 29), (0.049, 159), (0.046, 31), (0.043, 157), (0.041, 9), (0.041, 4), (0.040, 6), (0.035, 164), (0.035, 160), (0.026, 158), (0.017, 125), (0.017, 28), (0.017, 32), (0.017, 162), (0.017, 27), (0.017, 30), (0.017, 161), (0.009, 33), (0.009, 26), (0.009, 163), (0.006, 0)]
267
+ else:
268
+ if context.get_value('n') <= 68.0:
269
+ return [(0.101, 182), (0.101, 59), (0.088, 57), (0.076, 184), (0.076, 61), (0.076, 179), (0.076, 62), (0.076, 58), (0.063, 180), (0.063, 60), (0.051, 56), (0.050, 181), (0.025, 130), (0.025, 177), (0.025, 183), (0.013, 178), (0.013, 55)]
270
+ else:
271
+ return [(0.089, 180), (0.079, 60), (0.066, 35), (0.066, 181), (0.066, 38), (0.066, 58), (0.066, 179), (0.066, 57), (0.062, 184), (0.053, 37), (0.044, 166), (0.040, 55), (0.040, 39), (0.040, 36), (0.040, 165), (0.040, 167), (0.027, 177), (0.027, 34), (0.022, 159)]
272
+ else:
273
+ if context.get_value('m*n') <= 309760.0:
274
+ return [(0.298, 0), (0.097, 140), (0.080, 83), (0.072, 86), (0.044, 84), (0.036, 178), (0.036, 117), (0.036, 82), (0.032, 120), (0.032, 85), (0.028, 119), (0.024, 130), (0.024, 109), (0.020, 108), (0.020, 118), (0.012, 104), (0.012, 116), (0.012, 141), (0.012, 144), (0.008, 105), (0.008, 106), (0.008, 111), (0.008, 114), (0.008, 107), (0.008, 132), (0.004, 101), (0.004, 102), (0.004, 110), (0.004, 112), (0.004, 113), (0.004, 131)]
275
+ else:
276
+ if context.get_value('n') <= 72.0:
277
+ return [(0.227, 77), (0.118, 78), (0.102, 194), (0.086, 80), (0.059, 57), (0.054, 81), (0.049, 196), (0.048, 197), (0.048, 59), (0.043, 79), (0.032, 195), (0.027, 180), (0.022, 3), (0.021, 141), (0.016, 60), (0.016, 142), (0.011, 183), (0.011, 0), (0.011, 144)]
278
+ else:
279
+ return [(0.140, 186), (0.132, 185), (0.109, 63), (0.085, 65), (0.078, 37), (0.077, 35), (0.062, 197), (0.047, 194), (0.046, 165), (0.046, 57), (0.039, 78), (0.039, 79), (0.039, 66), (0.039, 64), (0.016, 195), (0.008, 159)]
280
+ else:
281
+ if str(context.get_value('using_tf32')) != 'False':
282
+ if context.get_value('m*n') <= 815360.0:
283
+ if context.get_value('k') <= 1184.0:
284
+ return [(0.218, 140), (0.205, 0), (0.154, 144), (0.115, 141), (0.051, 185), (0.051, 104), (0.039, 78), (0.038, 116), (0.026, 165), (0.026, 130), (0.026, 178), (0.013, 57), (0.013, 195), (0.013, 167), (0.013, 186)]
285
+ else:
286
+ return [(0.901, 0), (0.030, 144), (0.030, 134), (0.016, 3), (0.006, 78), (0.006, 77), (0.002, 57), (0.002, 194), (0.002, 59), (0.002, 60), (0.002, 143)]
287
+ else:
288
+ if context.get_value('arith_intensity') <= 187.23922729492188:
289
+ if context.get_value('mat1_stride_0') <= 198.0:
290
+ return [(0.273, 63), (0.158, 37), (0.152, 35), (0.127, 57), (0.097, 165), (0.053, 185), (0.031, 0), (0.028, 64), (0.014, 60), (0.014, 78), (0.009, 55), (0.008, 134), (0.005, 34), (0.005, 167), (0.005, 179), (0.005, 65), (0.005, 66), (0.005, 186), (0.005, 194), (0.002, 166)]
291
+ else:
292
+ return [(0.296, 63), (0.235, 0), (0.132, 64), (0.074, 37), (0.069, 78), (0.051, 185), (0.051, 35), (0.030, 57), (0.020, 77), (0.016, 194), (0.008, 66), (0.007, 65), (0.003, 3), (0.003, 165), (0.003, 141), (0.001, 134), (0.001, 166)]
293
+ else:
294
+ return [(0.405, 0), (0.246, 37), (0.177, 63), (0.145, 35), (0.005, 185), (0.005, 65), (0.005, 64), (0.004, 57), (0.003, 66), (0.002, 165), (0.001, 78), (0.001, 55)]
295
+ else:
296
+ return [(0.357, 0), (0.112, 165), (0.101, 57), (0.094, 179), (0.086, 64), (0.074, 167), (0.067, 60), (0.064, 159), (0.033, 35), (0.007, 195), (0.002, 180), (0.001, 34), (0.001, 166), (0.001, 78)]
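For reference, the generated heuristic above is self-contained enough to query directly. Below is a minimal usage sketch (an illustration only, not part of this commit), assuming the torch build bundled in this .venv exposes the autoheuristic modules the file itself imports:

    # Hypothetical usage sketch; the module path follows the file location above.
    from torch._inductor.autoheuristic.artifacts._MMRankingA100 import MMRankingA100

    heuristic = MMRankingA100()                  # __init__ calls fill_choices()
    print(heuristic.get_name())                  # 'mm'
    print(heuristic.get_confidence_threshold())  # 0.0
    print(heuristic.get_choice(0))               # 'extern_mm', the first ranked choice
    print(heuristic.get_choice(10_000))          # None for an out-of-range index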
.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py ADDED
@@ -0,0 +1,321 @@
1
+ # flake8: noqa: B950
2
+ # fmt: off
3
+ # This file was generated by AutoHeuristic. Do not modify it manually!
4
+ # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mm/
5
+ from typing import List, Optional, Tuple
6
+
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ Choice,
11
+ )
12
+ from torch._inductor.autoheuristic.learnedheuristic_interface import (
13
+ LearnedHeuristicDecision,
14
+ )
15
+
16
+
17
+ class MMRankingH100(LearnedHeuristicDecision):
18
+
19
+ def __init__(self) -> None:
20
+ self.choices: List[Choice] = []
21
+ self.fill_choices()
22
+
23
+ def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
24
+ return (
25
+ metadata.name == self.get_name()
26
+ and metadata.shared_memory == 232448
27
+ and str(metadata.device_capa) == "(9, 0)"
28
+ )
29
+
30
+ def get_confidence_threshold(self) -> float:
31
+ return 0.0
32
+
33
+ def get_choice(self, idx: int) -> Optional[str]:
34
+ if idx < len(self.choices):
35
+ return self.choices[idx]
36
+ return None
37
+
38
+ def fill_choices(self) -> None:
39
+ self.choices.append('extern_mm')
40
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=8')
41
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=8')
42
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=8')
43
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
44
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
45
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
46
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
47
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
48
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
49
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
50
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=8')
51
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
52
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=8')
53
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
54
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=8')
55
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
56
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=8')
57
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
58
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=8')
59
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
60
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
61
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=8')
62
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
63
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
64
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=8')
65
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
66
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
67
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
68
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
69
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
70
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
71
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
72
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
73
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=8')
74
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
75
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=4')
76
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=5_numwarps=8')
77
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
78
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=8')
79
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
80
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=8')
81
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
82
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=8')
83
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
84
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=8')
85
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=2')
86
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=8')
87
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
88
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
89
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
90
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
91
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
92
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
93
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=2')
94
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=8')
95
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
96
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
97
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
98
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
99
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
100
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
101
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
102
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
103
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
104
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
105
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
106
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=8')
107
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=8')
108
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
109
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=8')
110
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
111
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
112
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=8')
113
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
114
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=8')
115
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
116
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
117
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=8')
118
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
119
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=8')
120
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
121
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
122
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
123
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
124
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
125
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
126
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=2_numwarps=8')
127
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
128
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=8')
129
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
130
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
131
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=4')
132
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=128_numstages=5_numwarps=8')
133
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=1')
134
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=1')
135
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=1')
136
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
137
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=2')
138
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=2')
139
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=2')
140
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=2')
141
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=2')
142
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
143
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
144
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
145
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
146
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=2_numwarps=8')
147
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
148
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
149
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
150
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=1')
151
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=1')
152
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=2')
153
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=2')
154
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
155
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
156
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
157
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
158
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=8')
159
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=4')
160
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
161
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
162
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=16_numstages=2_numwarps=2')
163
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
164
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
165
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
166
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=1_numwarps=2')
167
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=2')
168
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=2')
169
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=1_numwarps=2')
170
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
171
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
172
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=8')
173
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=2')
174
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=2')
175
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
176
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
177
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
178
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=16_numstages=2_numwarps=2')
179
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=64_BLOCK-N=32_numstages=2_numwarps=4')
180
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
181
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=3_numwarps=4')
182
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=4_numwarps=4')
183
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=16_numstages=5_numwarps=4')
184
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=3_numwarps=4')
185
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=4_numwarps=4')
186
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
187
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=3_numwarps=4')
188
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=4_numwarps=4')
189
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
190
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=3_numwarps=4')
191
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=4')
192
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=128_numstages=4_numwarps=8')
193
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=2_numwarps=4')
194
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=3_numwarps=4')
195
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=4_numwarps=4')
196
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=16_numstages=5_numwarps=4')
197
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=2_numwarps=4')
198
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=3_numwarps=4')
199
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=4_numwarps=4')
200
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=4')
201
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=32_numstages=5_numwarps=8')
202
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=2_numwarps=4')
203
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=4')
204
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=3_numwarps=8')
205
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=4')
206
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=4_numwarps=8')
207
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=16_BLOCK-N=64_numstages=5_numwarps=4')
208
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
209
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=4')
210
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=128_numstages=4_numwarps=8')
211
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=2_numwarps=4')
212
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=3_numwarps=4')
213
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=4_numwarps=4')
214
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=16_numstages=5_numwarps=4')
215
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=2_numwarps=4')
216
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=4')
217
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=3_numwarps=8')
218
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=4')
219
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=4_numwarps=8')
220
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=4')
221
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=32_numstages=5_numwarps=8')
222
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=2_numwarps=4')
223
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
224
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=8')
225
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=4')
226
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=4_numwarps=8')
227
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=4')
228
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
229
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=4_numwarps=4')
230
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=3_numwarps=4')
231
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=4_numwarps=4')
232
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=16_numstages=5_numwarps=4')
233
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=4')
234
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=3_numwarps=8')
235
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=4_numwarps=4')
236
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=32_numstages=5_numwarps=4')
237
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
238
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
239
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=4_numwarps=4')
240
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=5_numwarps=4')
241
+
242
+ def get_name(self) -> str:
243
+ return 'mm'
244
+
245
+ def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
246
+ if context.get_value('arith_intensity') <= 29.89772129058838:
247
+ if context.get_value('n') <= 34.0:
248
+ if context.get_value('n') <= 18.0:
249
+ if context.get_value('k*n') <= 432.0:
250
+ if context.get_value('arith_intensity') <= 7.8700292110443115:
251
+ return [(0.098, 128), (0.098, 129), (0.098, 127), (0.073, 14), (0.073, 16), (0.073, 12), (0.073, 154), (0.073, 156), (0.073, 157), (0.073, 155), (0.049, 10), (0.049, 94), (0.049, 95), (0.048, 96)]
252
+ else:
253
+ return [(0.091, 154), (0.073, 10), (0.073, 15), (0.073, 13), (0.073, 11), (0.073, 17), (0.073, 16), (0.073, 14), (0.073, 12), (0.055, 127), (0.054, 157), (0.054, 156), (0.054, 155), (0.036, 129), (0.036, 128), (0.018, 41), (0.018, 43)]
254
+ else:
255
+ if context.get_value('k') <= 40.0:
256
+ return [(0.070, 39), (0.069, 45), (0.069, 41), (0.069, 43), (0.069, 111), (0.069, 112), (0.056, 38), (0.056, 40), (0.056, 42), (0.056, 44), (0.056, 174), (0.056, 173), (0.056, 175), (0.056, 134), (0.056, 172), (0.056, 135), (0.014, 154), (0.014, 127)]
257
+ else:
258
+ return [(0.147, 144), (0.119, 143), (0.087, 142), (0.083, 0), (0.073, 191), (0.059, 69), (0.050, 67), (0.046, 70), (0.041, 1), (0.036, 174), (0.032, 43), (0.032, 123), (0.028, 40), (0.027, 42), (0.027, 173), (0.023, 175), (0.018, 66), (0.014, 192), (0.014, 193), (0.014, 139), (0.014, 68), (0.014, 127)]
259
+ else:
260
+ if context.get_value('mat1_stride_0') <= 40.0:
261
+ if context.get_value('mat1_stride_0') <= 20.0:
262
+ return [(0.109, 23), (0.109, 21), (0.109, 20), (0.088, 0), (0.087, 131), (0.066, 18), (0.065, 130), (0.065, 132), (0.065, 159), (0.065, 160), (0.065, 161), (0.065, 158), (0.022, 22), (0.022, 19)]
263
+ else:
264
+ return [(0.065, 46), (0.064, 52), (0.064, 50), (0.064, 48), (0.064, 51), (0.064, 49), (0.064, 47), (0.064, 53), (0.064, 181), (0.064, 177), (0.064, 179), (0.064, 176), (0.038, 130), (0.038, 136), (0.026, 182), (0.026, 178), (0.026, 180), (0.026, 137), (0.025, 158), (0.013, 114), (0.013, 113)]
265
+ else:
266
+ if context.get_value('mat1_stride_0') <= 68.0:
267
+ return [(0.138, 140), (0.125, 195), (0.100, 71), (0.100, 74), (0.100, 196), (0.100, 194), (0.100, 197), (0.075, 75), (0.062, 72), (0.062, 73), (0.012, 180), (0.012, 51), (0.012, 182)]
268
+ else:
269
+ return [(0.124, 180), (0.124, 182), (0.114, 75), (0.103, 74), (0.093, 51), (0.093, 71), (0.072, 72), (0.062, 194), (0.052, 145), (0.052, 195), (0.021, 48), (0.021, 50), (0.021, 47), (0.020, 124), (0.010, 147), (0.010, 146), (0.010, 46)]
270
+ else:
271
+ if context.get_value('k') <= 18.0:
272
+ if context.get_value('m*k') <= 528.0:
273
+ return [(0.097, 88), (0.087, 92), (0.077, 90), (0.058, 105), (0.058, 103), (0.058, 104), (0.058, 99), (0.058, 100), (0.058, 106), (0.058, 93), (0.057, 91), (0.057, 97), (0.057, 98), (0.057, 101), (0.048, 102), (0.029, 87), (0.029, 89)]
274
+ else:
275
+ if context.get_value('n') <= 80.0:
276
+ return [(0.057, 161), (0.057, 130), (0.057, 24), (0.056, 164), (0.056, 163), (0.056, 166), (0.056, 168), (0.056, 30), (0.056, 28), (0.056, 26), (0.056, 25), (0.056, 27), (0.056, 29), (0.056, 31), (0.042, 131), (0.028, 99), (0.028, 101), (0.028, 100), (0.028, 167), (0.028, 165), (0.028, 133)]
277
+ else:
278
+ return [(0.110, 164), (0.108, 163), (0.106, 168), (0.069, 161), (0.066, 151), (0.060, 152), (0.055, 165), (0.050, 27), (0.050, 29), (0.048, 131), (0.043, 153), (0.037, 133), (0.037, 130), (0.028, 8), (0.028, 5), (0.027, 7), (0.026, 26), (0.016, 162), (0.012, 9), (0.007, 4), (0.005, 100), (0.005, 6), (0.005, 24)]
279
+ else:
280
+ if context.get_value('k') <= 36.0:
281
+ if context.get_value('n') <= 68.0:
282
+ return [(0.097, 184), (0.097, 56), (0.086, 186), (0.086, 183), (0.086, 188), (0.086, 58), (0.086, 60), (0.065, 54), (0.043, 187), (0.043, 185), (0.043, 57), (0.043, 61), (0.032, 55), (0.032, 130), (0.032, 59), (0.011, 181), (0.011, 163), (0.011, 136), (0.011, 138)]
283
+ else:
284
+ return [(0.117, 184), (0.117, 170), (0.117, 169), (0.107, 183), (0.106, 188), (0.075, 181), (0.064, 130), (0.064, 56), (0.053, 171), (0.032, 57), (0.032, 59), (0.032, 185), (0.011, 163), (0.011, 32), (0.011, 37), (0.011, 34), (0.011, 33), (0.011, 35), (0.011, 36), (0.011, 54)]
285
+ else:
286
+ if context.get_value('mat2_stride_0') <= 384.0:
287
+ return [(0.244, 0), (0.061, 76), (0.061, 79), (0.030, 3), (0.030, 183), (0.030, 189), (0.030, 187), (0.030, 64), (0.030, 190), (0.030, 62), (0.030, 198), (0.030, 201), (0.030, 77), (0.030, 200), (0.030, 80), (0.030, 199), (0.030, 78), (0.030, 184), (0.020, 86), (0.020, 84), (0.020, 120), (0.020, 81), (0.020, 121), (0.020, 85), (0.020, 122), (0.010, 83), (0.010, 118), (0.010, 119), (0.010, 82)]
288
+ else:
289
+ return [(0.274, 83), (0.171, 86), (0.152, 0), (0.071, 85), (0.061, 125), (0.050, 84), (0.020, 109), (0.020, 117), (0.020, 81), (0.020, 118), (0.020, 121), (0.020, 108), (0.020, 115), (0.020, 116), (0.010, 110), (0.010, 120), (0.010, 103), (0.010, 107), (0.010, 119), (0.010, 122)]
290
+ else:
291
+ if context.get_value('arith_intensity') <= 56.995582580566406:
292
+ if context.get_value('n') <= 68.0:
293
+ if context.get_value('k*n') <= 4448.0:
294
+ if context.get_value('m*n') <= 29626368.0:
295
+ return [(0.107, 198), (0.107, 200), (0.107, 201), (0.107, 199), (0.106, 76), (0.106, 79), (0.064, 197), (0.063, 56), (0.043, 184), (0.043, 187), (0.042, 80), (0.042, 77), (0.042, 183), (0.021, 78)]
296
+ else:
297
+ return [(0.073, 201), (0.073, 198), (0.073, 200), (0.073, 199), (0.073, 197), (0.073, 56), (0.073, 58), (0.073, 79), (0.073, 76), (0.072, 59), (0.072, 78), (0.072, 77), (0.072, 80), (0.018, 184), (0.018, 55), (0.018, 54)]
298
+ else:
299
+ if context.get_value('k') <= 348.0:
300
+ return [(0.206, 76), (0.183, 77), (0.169, 198), (0.160, 199), (0.053, 59), (0.046, 56), (0.038, 3), (0.030, 148), (0.030, 58), (0.030, 187), (0.023, 184), (0.015, 0), (0.008, 55), (0.008, 54)]
301
+ else:
302
+ return [(0.146, 198), (0.145, 199), (0.145, 148), (0.126, 0), (0.084, 76), (0.084, 77), (0.042, 80), (0.042, 79), (0.021, 149), (0.021, 150), (0.021, 3), (0.014, 46), (0.014, 74), (0.014, 75), (0.014, 124), (0.014, 194), (0.014, 195), (0.007, 145), (0.007, 146), (0.007, 2), (0.007, 72), (0.007, 147), (0.007, 71)]
303
+ else:
304
+ if context.get_value('m') <= 3264.0:
305
+ return [(0.247, 147), (0.115, 197), (0.066, 199), (0.066, 201), (0.066, 198), (0.049, 0), (0.049, 169), (0.049, 171), (0.033, 140), (0.033, 125), (0.033, 114), (0.016, 126), (0.016, 183), (0.016, 184), (0.016, 185), (0.016, 182), (0.016, 188), (0.016, 78), (0.016, 148), (0.016, 138), (0.016, 77), (0.016, 56), (0.016, 59)]
306
+ else:
307
+ if context.get_value('k') <= 62.5:
308
+ return [(0.226, 190), (0.226, 189), (0.122, 62), (0.122, 64), (0.055, 77), (0.055, 78), (0.037, 198), (0.036, 201), (0.036, 33), (0.024, 163), (0.018, 56), (0.018, 35), (0.018, 169), (0.006, 171)]
309
+ else:
310
+ return [(0.162, 35), (0.118, 33), (0.096, 189), (0.096, 190), (0.088, 169), (0.074, 62), (0.073, 56), (0.066, 171), (0.051, 198), (0.051, 201), (0.044, 59), (0.037, 64), (0.029, 63), (0.007, 0), (0.007, 77)]
311
+ else:
312
+ if context.get_value('m*n') <= 1097728.0:
313
+ return [(0.403, 0), (0.179, 141), (0.134, 150), (0.086, 147), (0.051, 148), (0.048, 3), (0.024, 189), (0.020, 199), (0.017, 64), (0.010, 65), (0.010, 77), (0.007, 114), (0.003, 138), (0.003, 59), (0.003, 182)]
314
+ else:
315
+ if context.get_value('m*n') <= 3244032.0:
316
+ return [(0.295, 189), (0.176, 64), (0.157, 65), (0.090, 0), (0.069, 62), (0.059, 63), (0.046, 77), (0.039, 169), (0.023, 199), (0.020, 35), (0.013, 33), (0.010, 171), (0.003, 141)]
317
+ else:
318
+ if context.get_value('n') <= 136.0:
319
+ return [(0.197, 189), (0.197, 63), (0.161, 77), (0.157, 62), (0.061, 33), (0.044, 65), (0.039, 35), (0.039, 64), (0.030, 169), (0.026, 0), (0.017, 199), (0.017, 148), (0.009, 56), (0.004, 3)]
320
+ else:
321
+ return [(0.460, 0), (0.145, 62), (0.138, 63), (0.081, 35), (0.047, 33), (0.043, 189), (0.023, 64), (0.018, 77), (0.013, 169), (0.009, 65), (0.009, 56), (0.005, 32), (0.005, 59), (0.002, 183), (0.002, 163)]
.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py ADDED
@@ -0,0 +1,149 @@
1
+ # flake8: noqa: B950
2
+ # fmt: off
3
+ # This file was generated by AutoHeuristic. Do not modify it manually!
4
+ # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
5
+ from typing import List, Optional, Tuple
6
+
7
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
8
+ AHContext,
9
+ AHMetadata,
10
+ Choice,
11
+ )
12
+ from torch._inductor.autoheuristic.learnedheuristic_interface import (
13
+ LearnedHeuristicDecision,
14
+ )
15
+
16
+
17
+ class MixedMMH100(LearnedHeuristicDecision):
18
+
19
+ def __init__(self) -> None:
20
+ self.choices: List[Choice] = []
21
+ self.fill_choices()
22
+
23
+ def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
24
+ return (
25
+ metadata.name == self.get_name()
26
+ and metadata.shared_memory == 232448
27
+ and str(metadata.device_capa) == "(9, 0)"
28
+ )
29
+
30
+ def get_confidence_threshold(self) -> float:
31
+ return 0.0
32
+
33
+ def get_choice(self, idx: int) -> Optional[str]:
34
+ if idx < len(self.choices):
35
+ return self.choices[idx]
36
+ return None
37
+
38
+ def fill_choices(self) -> None:
39
+ self.choices.append('extern_fallback_mixed_mm')
40
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=128_numstages=3_numwarps=4')
41
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=32_BLOCK-N=64_numstages=3_numwarps=4')
42
+ self.choices.append('type=triton_BLOCK-M=128_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
43
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
44
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=2')
45
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=2')
46
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
47
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=3_numwarps=4')
48
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=256_BLOCK-N=128_numstages=5_numwarps=8')
49
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=128_numstages=5_numwarps=8')
50
+ self.choices.append('type=triton_BLOCK-M=16_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=4')
51
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
52
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=2_numwarps=4')
53
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
54
+ self.choices.append('type=triton_BLOCK-M=32_BLOCK-K=32_BLOCK-N=64_numstages=5_numwarps=8')
55
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=128_numstages=4_numwarps=4')
56
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=32_numstages=5_numwarps=4')
57
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=128_BLOCK-N=64_numstages=5_numwarps=4')
58
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=128_numstages=3_numwarps=4')
59
+ self.choices.append('type=triton_BLOCK-M=64_BLOCK-K=64_BLOCK-N=64_numstages=3_numwarps=8')
60
+
61
+ def get_name(self) -> str:
62
+ return 'mixed_mm'
63
+
64
+ def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
65
+ if context.get_value('arith_intensity') <= 15.988086223602295:
66
+ if context.get_value('n') <= 25280.0:
67
+ if context.get_value('n') <= 1344.0:
68
+ if context.get_value('mat1_stride_0') <= 7808.0:
69
+ return [(0.581, 7), (0.419, 6)]
70
+ else:
71
+ if context.get_value('m*n') <= 7680.0:
72
+ return [(0.875, 0), (0.125, 6)]
73
+ else:
74
+ return [(0.833, 0), (0.167, 7)]
75
+ else:
76
+ if context.get_value('n') <= 8512.0:
77
+ if str(context.get_value('mat2_dtype')) != 'torch.int8':
78
+ return [(0.763, 6), (0.237, 7)]
79
+ else:
80
+ return [(0.725, 7), (0.275, 6)]
81
+ else:
82
+ if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
83
+ return [(0.736, 7), (0.197, 9), (0.048, 6), (0.014, 8), (0.005, 10)]
84
+ else:
85
+ return [(0.473, 7), (0.398, 6), (0.097, 9), (0.032, 10)]
86
+ else:
87
+ if context.get_value('n') <= 42254.0:
88
+ if context.get_value('n') <= 33856.0:
89
+ if context.get_value('k*n') <= 68157440.0:
90
+ return [(0.370, 4), (0.370, 5), (0.074, 7), (0.074, 8), (0.074, 11), (0.037, 6)]
91
+ else:
92
+ return [(0.916, 8), (0.036, 7), (0.036, 9), (0.012, 4)]
93
+ else:
94
+ return [(0.659, 5), (0.341, 6)]
95
+ else:
96
+ if context.get_value('k*n') <= 326052992.0:
97
+ if context.get_value('n') <= 55232.0:
98
+ return [(0.571, 6), (0.321, 7), (0.036, 4), (0.036, 8), (0.036, 9)]
99
+ else:
100
+ return [(0.506, 6), (0.325, 8), (0.104, 7), (0.039, 5), (0.026, 9)]
101
+ else:
102
+ if context.get_value('n') <= 57024.0:
103
+ return [(0.462, 9), (0.385, 7), (0.115, 6), (0.038, 8)]
104
+ else:
105
+ return [(0.598, 8), (0.223, 9), (0.107, 6), (0.071, 7)]
106
+ else:
107
+ if context.get_value('m*n') <= 543936.0:
108
+ if str(context.get_value('17LEQmLEQ32')) != 'True':
109
+ if context.get_value('m*n') <= 262272.0:
110
+ if context.get_value('n') <= 1592.5:
111
+ return [(0.860, 0), (0.140, 9)]
112
+ else:
113
+ return None
114
+ else:
115
+ if context.get_value('m*k') <= 1294336.0:
116
+ return [(0.833, 17), (0.150, 18), (0.017, 15)]
117
+ else:
118
+ return [(0.917, 17), (0.083, 8)]
119
+ else:
120
+ if context.get_value('n') <= 12416.0:
121
+ if context.get_value('m*n') <= 43008.0:
122
+ return None
123
+ else:
124
+ return [(0.853, 14), (0.147, 9)]
125
+ else:
126
+ return [(0.625, 12), (0.375, 14)]
127
+ else:
128
+ if context.get_value('m') <= 32.5:
129
+ if context.get_value('mat2_stride_1') <= 6656.0:
130
+ if context.get_value('n') <= 69184.0:
131
+ return [(0.611, 12), (0.361, 14), (0.028, 13)]
132
+ else:
133
+ return [(1.000, 12)]
134
+ else:
135
+ if context.get_value('mat2_stride_1') <= 20864.0:
136
+ return [(1.000, 12)]
137
+ else:
138
+ return [(0.958, 12), (0.042, 9)]
139
+ else:
140
+ if context.get_value('m*n') <= 1085440.0:
141
+ if context.get_value('n') <= 9152.0:
142
+ return [(1.000, 18)]
143
+ else:
144
+ return [(0.780, 18), (0.160, 16), (0.060, 20)]
145
+ else:
146
+ if context.get_value('m') <= 67.0:
147
+ return [(0.650, 16), (0.203, 19), (0.122, 18), (0.016, 20), (0.008, 1)]
148
+ else:
149
+ return [(0.561, 3), (0.185, 16), (0.096, 20), (0.083, 19), (0.076, 2)]
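A minimal sketch (not part of the upload) of how a LearnedHeuristicDecision such as MixedMMH100 above is consumed: get_best_choices returns (probability, index) pairs from the decision tree, and get_choice maps each index back to the kernel-config string registered in fill_choices. The helper name and the pre-populated AHContext are assumptions for illustration only.

    from typing import List, Optional, Tuple

    from torch._inductor.autoheuristic.artifacts._MixedMMH100 import MixedMMH100
    from torch._inductor.autoheuristic.autoheuristic_utils import AHContext


    def ranked_configs(ctx: AHContext) -> Optional[List[Tuple[float, Optional[str]]]]:
        # ctx is assumed to already carry the features the tree reads
        # (m, n, k, arith_intensity, mat1_dtype, ...).
        heuristic = MixedMMH100()
        ranked = heuristic.get_best_choices(ctx)  # e.g. [(0.581, 7), (0.419, 6)] or None
        if ranked is None:
            return None
        # Each index refers into heuristic.choices; translate it to the config string.
        return [(prob, heuristic.get_choice(idx)) for prob, idx in ranked]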
.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py ADDED
@@ -0,0 +1,109 @@
1
+ # flake8: noqa: B950
2
+ # fmt: off
3
+ # This file was generated by AutoHeuristic. Do not modify it manually!
4
+ # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/
5
+ from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL
6
+ from torch._inductor.autoheuristic.learnedheuristic_interface import (
7
+ LearnedHeuristicRegression,
8
+ )
9
+
10
+
11
+ class PadMMA100(LearnedHeuristicRegression):
12
+
13
+ def __init__(self) -> None:
14
+ pass
15
+
16
+ def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
17
+ return (
18
+ metadata.name == self.get_name()
19
+ and metadata.shared_memory == 166912
20
+ and str(metadata.device_capa) == "(8, 0)"
21
+ )
22
+
23
+ def get_feedback(self, context: AHContext, choice: Choice) -> float:
24
+ context.context_dict[CHOICE_COL] = choice
25
+ return self.predict(context)
26
+
27
+ def get_confidence_threshold(self) -> float:
28
+ return 1.7025303314066
29
+
30
+ def get_name(self) -> str:
31
+ return 'pad_mm'
32
+
33
+ def predict(self, context: AHContext) -> float:
34
+ if str(context.get_value('choice')) != 'pad':
35
+ if str(context.get_value('using_tf32')) != 'False':
36
+ if context.get_value('m*n') <= 4171264.0:
37
+ if context.get_value('m*k') <= 3999308.0:
38
+ return 1.8751469764071178
39
+ else:
40
+ if str(context.get_value('n_multiple_32')) != 'True':
41
+ return 0.9117231355626345
42
+ else:
43
+ return 1.1607689608873861
44
+ else:
45
+ if str(context.get_value('n_multiple_2')) != 'True':
46
+ if str(context.get_value('using_tf32')) != 'True':
47
+ return 0.7430382200435992
48
+ else:
49
+ return 0.8531269794448678
50
+ else:
51
+ if str(context.get_value('k_multiple_2')) != 'True':
52
+ return 0.7577181972719917
53
+ else:
54
+ return 0.8977349440424219
55
+ else:
56
+ if context.get_value('m*n') <= 1299712.0:
57
+ return 1.1669723418995592
58
+ else:
59
+ if context.get_value('mat2_stride_1') <= 45217.5:
60
+ if context.get_value('m*n') <= 55884158.0:
61
+ return 1.0262769936909601
62
+ else:
63
+ return 1.0022677428470845
64
+ else:
65
+ if context.get_value('m') <= 18478.0:
66
+ return 1.1127066261894312
67
+ else:
68
+ return 1.0337740659894263
69
+ else:
70
+ if str(context.get_value('mat1_dtype')) != 'torch.float32':
71
+ if str(context.get_value('n_multiple_2')) != 'False':
72
+ if str(context.get_value('k_multiple_2')) != 'True':
73
+ if context.get_value('mat1_stride_0') <= 561.0:
74
+ return 1.2900382135142956
75
+ else:
76
+ return 1.5761737616057887
77
+ else:
78
+ if context.get_value('num_dims_needs_padding') <= 1.5:
79
+ return 1.0472263310239422
80
+ else:
81
+ return 1.1727673465762514
82
+ else:
83
+ if context.get_value('k') <= 28238.5:
84
+ if context.get_value('k/(m*n)') <= 0.00026227018679492176:
85
+ return 1.6770542505397175
86
+ else:
87
+ return 1.3974785435105923
88
+ else:
89
+ if str(context.get_value('mat1_dtype')) != 'torch.bfloat16':
90
+ return 1.3952699800111992
91
+ else:
92
+ return 1.5759286511628336
93
+ else:
94
+ if str(context.get_value('using_tf32')) != 'False':
95
+ if context.get_value('m*n') <= 14119424.0:
96
+ return 0.8875772670422478
97
+ else:
98
+ if str(context.get_value('mat2_innermost_needs_padding')) != 'True':
99
+ return 1.1467728924377265
100
+ else:
101
+ return 1.215842963532998
102
+ else:
103
+ if context.get_value('arith_intensity') <= 396.8774871826172:
104
+ return 0.89940161869551
105
+ else:
106
+ if context.get_value('mat2_stride_1') <= 45217.5:
107
+ return 0.9964328169353532
108
+ else:
109
+ return 0.9493479238294826
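PadMMA100 above is a regression heuristic rather than a ranking one: get_feedback stamps the candidate choice into the context and returns the value predicted by the decision tree in predict. A hedged sketch of comparing the predicted scores for the available choices (the helper name and the populated context are illustrative only; how the scores and get_confidence_threshold() are combined is up to the caller):

    from torch._inductor.autoheuristic.artifacts._PadMMA100 import PadMMA100
    from torch._inductor.autoheuristic.autoheuristic_utils import AHContext


    def score_choices(ctx: AHContext, choices):
        # Returns {choice: predicted score} for each candidate (e.g. 'pad' vs. not).
        heuristic = PadMMA100()
        return {choice: heuristic.get_feedback(ctx, choice) for choice in choices}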
.venv/Lib/site-packages/torch/_inductor/autoheuristic/artifacts/__init__.py ADDED
File without changes
.venv/Lib/site-packages/torch/_inductor/codegen/aoti_hipify_utils.py ADDED
@@ -0,0 +1,32 @@
1
+ # mypy: allow-untyped-defs
2
+ import re
3
+
4
+ import torch
5
+ from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE
6
+
7
+
8
+ # It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like:
9
+ # "...
10
+ # from ..codecache import CudaKernelParamCache
11
+ # ..."
12
+ # In such cases, we do not need to hipify_torch the original class/file name in codegen/codecache
13
+
14
+
15
+ def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> str:
16
+ if torch.version.hip is None and not force_hipify:
17
+ return source_codes
18
+
19
+ def c2_repl(m):
20
+ return PYTORCH_MAP[m.group(0)]
21
+
22
+ # We need to redefine RE_PYTORCH_PREPROCESSOR here since in hipify_torch,
23
+ # it will apply positive lookbehind (?<=\W) to the pattern to avoid matching
24
+ # keyword at the beginning of code line. However, this can happen in codegen,
25
+ # which will cause the pattern to not match.
26
+
27
+ # Note that lookahead (?=\W) is still needed to keep hipification idempotent, for example
28
+ # we need to skip replacing "getStreamFromExternal" in "getStreamFromExternalMasqueradingAsCUDA"
29
+ RE_PYTORCH_PREPROCESSOR = re.compile(rf"({PYTORCH_TRIE.export_to_regex()})(?=\W)")
30
+
31
+ source_codes = RE_PYTORCH_PREPROCESSOR.sub(c2_repl, source_codes)
32
+ return source_codes
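A small sketch of what maybe_hipify_code_wrapper does (the input string is made up for illustration): on a CUDA build it is a no-op unless force_hipify=True, in which case identifiers known to the hipify mappings are rewritten.

    from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper

    src = "cudaStream_t stream;"
    print(maybe_hipify_code_wrapper(src))                     # unchanged when torch.version.hip is None
    print(maybe_hipify_code_wrapper(src, force_hipify=True))  # cudaStream_t is expected to become hipStream_t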
.venv/Lib/site-packages/torch/_inductor/codegen/codegen_device_driver.py ADDED
@@ -0,0 +1,91 @@
1
+ import torch
2
+
3
+
4
+ # Provide aoti module launch hip/cuda drivers. This file is also used for unit testing purposes
5
+
6
+
7
+ def cuda_kernel_driver() -> str:
8
+ source_codes = """
9
+ #define CUDA_DRIVER_CHECK(EXPR) \\
10
+ do { \\
11
+ CUresult code = EXPR; \\
12
+ const char *msg; \\
13
+ cuGetErrorString(code, &msg); \\
14
+ if (code != CUDA_SUCCESS) { \\
15
+ throw std::runtime_error( \\
16
+ std::string("CUDA driver error: ") + \\
17
+ std::string(msg)); \\
18
+ } \\
19
+ } while (0);
20
+
21
+ namespace {
22
+
23
+ struct Grid {
24
+ Grid(uint32_t x, uint32_t y, uint32_t z)
25
+ : grid_x(x), grid_y(y), grid_z(z) {}
26
+ uint32_t grid_x;
27
+ uint32_t grid_y;
28
+ uint32_t grid_z;
29
+
30
+ bool is_non_zero() {
31
+ return grid_x > 0 && grid_y > 0 && grid_z > 0;
32
+ }
33
+ };
34
+
35
+ } // anonymous namespace
36
+
37
+ static inline CUfunction loadKernel(
38
+ std::string filePath,
39
+ const std::string &funcName,
40
+ uint32_t sharedMemBytes,
41
+ const std::optional<std::string> &cubinDir = std::nullopt) {
42
+ if (cubinDir) {
43
+ std::filesystem::path p1{*cubinDir};
44
+ std::filesystem::path p2{filePath};
45
+ filePath = (p1 / p2.filename()).string();
46
+ }
47
+
48
+ CUmodule mod;
49
+ CUfunction func;
50
+ CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
51
+ CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
52
+ if (sharedMemBytes > 0) {
53
+ CUDA_DRIVER_CHECK(cuFuncSetAttribute(
54
+ func,
55
+ CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
56
+ sharedMemBytes
57
+ ))
58
+ }
59
+ return func;
60
+ }
61
+
62
+ static inline void launchKernel(
63
+ CUfunction func,
64
+ uint32_t gridX,
65
+ uint32_t gridY,
66
+ uint32_t gridZ,
67
+ uint32_t numWarps,
68
+ uint32_t sharedMemBytes,
69
+ void* args[],
70
+ cudaStream_t stream) {
71
+ CUDA_DRIVER_CHECK(cuLaunchKernel(
72
+ func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
73
+ ));
74
+ }
75
+ """
76
+ if torch.version.hip is not None:
77
+ # Adjusting the warp size to GPU supported wavefront size on AMD GPU
78
+ prop = torch.cuda.get_device_properties(torch.cuda.current_device())
79
+ source_codes = source_codes.replace(
80
+ "32*numWarps", str(prop.warp_size) + "*numWarps"
81
+ )
82
+ return source_codes
83
+
84
+
85
+ def cuda_kernel_header() -> str:
86
+ source_codes = """
87
+ #include <c10/cuda/CUDAGuard.h>
88
+ #include <c10/cuda/CUDAStream.h>
89
+ #include <ATen/cuda/EmptyTensor.h>
90
+ """
91
+ return source_codes
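These two helpers only return C++ source text that the AOTI wrapper codegen splices into its output; nothing is compiled here. On ROCm builds the literal 32 in "32*numWarps" is replaced with the device's reported warp_size (commonly 64 on AMD GPUs). A quick sketch of calling them:

    from torch._inductor.codegen.codegen_device_driver import (
        cuda_kernel_driver,
        cuda_kernel_header,
    )

    driver_src = cuda_kernel_driver()   # plain string containing loadKernel/launchKernel
    header_src = cuda_kernel_header()   # CUDA includes for the generated wrapper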
.venv/Lib/site-packages/torch/_inductor/codegen/common.py ADDED
@@ -0,0 +1,2167 @@
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ import dataclasses
4
+ import functools
5
+ import itertools
6
+ import logging
7
+ import math
8
+ import operator
9
+ import re
10
+ from enum import auto, Enum
11
+ from itertools import chain
12
+ from typing import (
13
+ Any,
14
+ Callable,
15
+ ClassVar,
16
+ Dict,
17
+ List,
18
+ NamedTuple,
19
+ Optional,
20
+ Tuple,
21
+ Union,
22
+ )
23
+
24
+ import sympy
25
+ from sympy.printing.printer import Printer
26
+
27
+ import torch
28
+ import torch.fx
29
+ from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
30
+ from torch.utils import _pytree as pytree
31
+ from torch.utils._ordered_set import OrderedSet
32
+ from torch.utils._sympy.numbers import int_oo
33
+ from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
34
+ from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
35
+
36
+ from .. import config, metrics
37
+ from ..utils import (
38
+ DeferredLineBase,
39
+ generate_assert,
40
+ IndentedBuffer,
41
+ sympy_dot,
42
+ sympy_subs,
43
+ unique,
44
+ )
45
+ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
46
+
47
+
48
+ schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
49
+
50
+
51
+ def data_type_logger(msg):
52
+ if schedule_log.isEnabledFor(logging.DEBUG):
53
+ schedule_log.debug("Data type propagation: %s", msg)
54
+
55
+
56
+ @dataclasses.dataclass
57
+ class WorkspaceArg:
58
+ """A temporary buffer used for a single kernel, then discarded.
59
+
60
+ Not registered as a traditional buffer since there are no users,
61
+ so it would be dead code eliminated.
62
+ """
63
+
64
+ nbytes: sympy.Expr
65
+ zero_fill: bool
66
+
67
+
68
+ @dataclasses.dataclass
69
+ class TensorArg:
70
+ name: str
71
+ buffer: str
72
+ dtype: torch.dtype
73
+ offset: sympy.Expr = sympy.Integer(0) # c++ only
74
+ alias_of: Optional[str] = None # halide only
75
+
76
+
77
+ @dataclasses.dataclass
78
+ class SizeArg:
79
+ name: str
80
+ expr: sympy.Expr
81
+
82
+ @property
83
+ def alias_of(self):
84
+ return None
85
+
86
+
87
+ @dataclasses.dataclass
88
+ class DeviceCodegen:
89
+ scheduling: Any
90
+ wrapper_codegen: type
91
+ cpp_wrapper_codegen: type = type(None)
92
+
93
+
94
+ KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]
95
+
96
+ device_codegens: Dict[str, DeviceCodegen] = {}
97
+
98
+
99
+ class DeviceOpOverrides:
100
+ def import_get_raw_stream_as(self, name):
101
+ raise NotImplementedError
102
+
103
+ def set_device(self, device_idx):
104
+ raise NotImplementedError
105
+
106
+ def synchronize(self):
107
+ raise NotImplementedError
108
+
109
+ def device_guard(self, device_idx):
110
+ raise NotImplementedError
111
+
112
+
113
+ device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}
114
+
115
+
116
+ # The code generated by Inductor consists of two main parts: kernel code and wrapper code.
117
+ # For any new backend looking to integrate with Inductor, customization of these two main
118
+ # parts is necessary to generate its specific code.
119
+ #
120
+ # Kernel code generation is determined by different Scheduling. Consequently, a new
121
+ # backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
122
+ # CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
123
+ #
124
+ # For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
125
+ # that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
126
+ # and override specific member functions to create backend-specific Python wrapper code.
127
+ #
128
+ # Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
129
+ # of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
130
+ # provide flexibility to the backend. A backend can choose to implement these classes from scratch,
131
+ # or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
132
+ # register_backend_for_device, to equip a new backend at runtime.
133
+ #
134
+ # Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
135
+ # This backend can be used as a reference:
136
+ # https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
137
+ def register_backend_for_device(
138
+ device: str,
139
+ device_scheduling: Any,
140
+ device_wrapper_codegen: type,
141
+ device_cpp_wrapper_codegen: type = type(None),
142
+ ):
143
+ device_codegens[device] = DeviceCodegen(
144
+ device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen
145
+ )
146
+
147
+
148
+ class BackendFeature(Enum):
149
+ FOREACH = auto()
150
+ BUCKETIZE = auto()
151
+ INPLACE_BUFFERS = auto()
152
+ MASKED_SCATTER_WITH_INDEX = auto()
153
+ SCAN = auto()
154
+ SORT = auto()
155
+ TUPLE_REDUCTION = auto()
156
+ PREFER_STORE_LOOP_ORDER = auto()
157
+ TRITON_TEMPLATES = auto()
158
+ REDUCE_TO_SINGLE_ELEMENT = auto()
159
+
160
+
161
+ def get_backend_features(device: Union[torch.device, str]):
162
+ init_backend_registration()
163
+ if isinstance(device, torch.device):
164
+ device_type = device.type
165
+ else:
166
+ assert isinstance(device, str)
167
+ device_type = device
168
+ device = torch.device(device_type)
169
+ scheduling = get_scheduling_for_device(device_type)
170
+ return scheduling(None).get_backend_features(device)
171
+
172
+
173
+ def has_backend_feature(device, feature):
174
+ """See also V.graph.has_feature"""
175
+ assert isinstance(feature, BackendFeature)
176
+ return feature in get_backend_features(device)
177
+
178
+
179
+ def get_scheduling_for_device(device: str):
180
+ return device_codegens[device].scheduling if device in device_codegens else None
181
+
182
+
183
+ def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False):
184
+ if device in device_codegens:
185
+ wrapper_codegen_obj: DeviceCodegen = device_codegens[device]
186
+ return (
187
+ wrapper_codegen_obj.cpp_wrapper_codegen
188
+ if cpp_wrapper
189
+ else wrapper_codegen_obj.wrapper_codegen
190
+ )
191
+ else:
192
+ return None
193
+
194
+
195
+ @functools.lru_cache(None)
196
+ def init_backend_registration():
197
+ from .cpp import CppScheduling
198
+ from .cpp_wrapper_cpu import CppWrapperCpu
199
+ from .cpp_wrapper_cuda import CppWrapperCuda
200
+ from .cuda_combined_scheduling import CUDACombinedScheduling
201
+ from .halide import HalideScheduling
202
+ from .triton import TritonScheduling
203
+ from .wrapper import WrapperCodeGen
204
+
205
+ if get_scheduling_for_device("cpu") is None:
206
+ cpu_backends = {"cpp": CppScheduling, "halide": HalideScheduling}
207
+ register_backend_for_device(
208
+ "cpu",
209
+ lambda *args, **kwargs: cpu_backends[config.cpu_backend](*args, **kwargs),
210
+ WrapperCodeGen,
211
+ CppWrapperCpu,
212
+ )
213
+
214
+ if get_scheduling_for_device("cuda") is None:
215
+ # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
216
+ cuda_backends = {"triton": CUDACombinedScheduling, "halide": HalideScheduling}
217
+ register_backend_for_device(
218
+ "cuda",
219
+ lambda *args, **kwargs: cuda_backends[config.cuda_backend](*args, **kwargs),
220
+ WrapperCodeGen,
221
+ CppWrapperCuda,
222
+ )
223
+
224
+ if get_scheduling_for_device("xpu") is None:
225
+ register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen)
226
+
227
+ private_backend = torch._C._get_privateuse1_backend_name()
228
+ if (
229
+ private_backend != "privateuseone"
230
+ and get_scheduling_for_device(private_backend) is None
231
+ ):
232
+ from torch.utils.backend_registration import _get_custom_mod_func
233
+
234
+ try:
235
+ device_scheduling = _get_custom_mod_func("Scheduling")
236
+ wrapper_codegen = _get_custom_mod_func("WrapperCodeGen")
237
+ cpp_wrapper_codegen = _get_custom_mod_func("CppWrapperCodeGen")
238
+ if device_scheduling and wrapper_codegen and cpp_wrapper_codegen:
239
+ register_backend_for_device(
240
+ private_backend,
241
+ device_scheduling,
242
+ wrapper_codegen,
243
+ cpp_wrapper_codegen,
244
+ )
245
+ except RuntimeError:
246
+ pass
247
+
248
+
249
+ def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
250
+ from ..ir import FlexibleLayout
251
+
252
+ # added contiguous index prevents reordering
253
+ return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]
254
+
255
+
256
+ def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
257
+ device_op_overrides_dict[device] = device_op_overrides
258
+
259
+
260
+ def get_device_op_overrides(device: str):
261
+ assert isinstance(device, str)
262
+
263
+ if not device_op_overrides_dict.keys():
264
+ from .cuda import device_op_overrides # noqa: F401
265
+ from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401
266
+
267
+ if device in device_op_overrides_dict.keys():
268
+ return device_op_overrides_dict[device]
269
+
270
+
271
+ @functools.lru_cache(None)
272
+ def boolean_ops():
273
+ return (
274
+ "isinf",
275
+ "isnan",
276
+ "logical_not",
277
+ "signbit",
278
+ "le",
279
+ "lt",
280
+ "ge",
281
+ "gt",
282
+ "eq",
283
+ "ne",
284
+ )
285
+
286
+
287
+ DTYPE_TO_COMPUTATION_DTYPE = {
288
+ torch.bfloat16: torch.float,
289
+ torch.float16: torch.float,
290
+ **{
291
+ dtype: dtype
292
+ for dtype in [
293
+ torch.bool,
294
+ torch.float32,
295
+ torch.float64,
296
+ torch.int8,
297
+ torch.int16,
298
+ torch.int32,
299
+ torch.int64,
300
+ torch.uint8,
301
+ torch.uint16,
302
+ torch.uint32,
303
+ torch.uint64,
304
+ ]
305
+ },
306
+ }
307
+
308
+
309
+ def deduce_output_dtype_by_name(
310
+ op_name: str,
311
+ *args,
312
+ **kwargs,
313
+ ) -> Optional[torch.dtype]:
314
+ """
315
+ Given op name and a list of input dtypes, deduce the output dtype
316
+ """
317
+ if op_name in boolean_ops():
318
+ return torch.bool
319
+ elif op_name in (
320
+ "to_dtype",
321
+ "index_expr",
322
+ ):
323
+ return kwargs["dtype"] if "dtype" in kwargs else args[-1]
324
+ elif op_name in (
325
+ "rand",
326
+ "randn",
327
+ ):
328
+ return torch.float
329
+ elif op_name in (
330
+ "get_index",
331
+ "randint64",
332
+ "load_seed",
333
+ ):
334
+ return torch.int64
335
+ elif op_name == "reduction":
336
+ return kwargs["dtype"] if "dtype" in kwargs else args[1]
337
+ elif op_name == "constant":
338
+ dtype = kwargs["dtype"] if "dtype" in kwargs else args[-1]
339
+ return DTYPE_TO_COMPUTATION_DTYPE[dtype] # type: ignore[index]
340
+ elif op_name in (
341
+ "load",
342
+ "store",
343
+ "store_reduction",
344
+ ):
345
+ buf_name = args[1]
346
+ return V.graph.get_dtype(buf_name) # type: ignore[arg-type]
347
+ elif op_name == "to_dtype_bitcast":
348
+ return kwargs["dtype"] if "dtype" in kwargs else args[-2]
349
+ return None
350
+
351
+
352
+ class DataTypePropagation:
353
+ def __init__(self, body) -> None:
354
+ self.body = body
355
+ self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
356
+ "root": body.root_block.graph
357
+ }
358
+ for k, v in body.subblocks.items():
359
+ self.graphs[k] = v.graph
360
+
361
+ def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
362
+ inputs = node.all_input_nodes
363
+ input_nodes = [
364
+ n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
365
+ ]
366
+ if len(input_nodes) == 0:
367
+ return None
368
+
369
+ all_input_nodes_propagated = all(
370
+ OptimizationContext.key in n.meta
371
+ and n.meta[OptimizationContext.key].dtype is not None
372
+ for n in input_nodes
373
+ )
374
+ if not all_input_nodes_propagated:
375
+ return None
376
+
377
+ return functools.reduce(
378
+ torch.promote_types,
379
+ [n.meta[OptimizationContext.key].dtype for n in input_nodes],
380
+ )
381
+
382
+ def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
383
+ sub_graph = self.graphs[node.target]
384
+ dtype = self.propagate_graph(sub_graph)
385
+ assert dtype
386
+ return dtype
387
+
388
+ def deduce_node_dtype(self, node: torch.fx.Node):
389
+ if node.op == "placeholder":
390
+ return None
391
+
392
+ if node.target == "output" and len(node.args) != 1:
393
+ # we can only infer the output node's dtype if it has exactly 1 arg
394
+ return None
395
+
396
+ if node.target == operator.getitem:
397
+ return self.deduce_node_dtype(node.args[0]) # type: ignore[arg-type]
398
+
399
+ assert isinstance(node.target, str)
400
+
401
+ if node.target.startswith("masked_subblock"):
402
+ return self.deduce_node_dtype_by_subgraph(node)
403
+
404
+ if (
405
+ output_dtype := deduce_output_dtype_by_name(
406
+ node.target,
407
+ *node.args,
408
+ **node.kwargs,
409
+ )
410
+ ) is not None:
411
+ return output_dtype
412
+
413
+ return self.deduce_node_dtype_by_inputs(node)
414
+
415
+ def propagate_graph(self, graph: torch.fx.Graph):
416
+ assert graph.nodes
417
+ graph_dtype = None
418
+ # For masked_subblock, we use output's dtype to represent
419
+ # the dtype of this subgraph. For other cases, graph_dtype
420
+ # might be None
421
+ for node in graph.nodes:
422
+ if OptimizationContext.key in node.meta:
423
+ opt_ctx = node.meta[OptimizationContext.key]
424
+ else:
425
+ opt_ctx = OptimizationContext()
426
+
427
+ opt_ctx.dtype = self.deduce_node_dtype(node)
428
+ node.meta[OptimizationContext.key] = opt_ctx
429
+ if node.target == "output":
430
+ graph_dtype = opt_ctx.dtype
431
+ return graph_dtype
432
+
433
+ def propagate(self):
434
+ self.propagate_graph(self.graphs["root"])
435
+
436
+ @classmethod
437
+ def propagate_loopbody(cls, body):
438
+ return cls(body).propagate()
439
+
440
+ @classmethod
441
+ def propagate_scheduler_node(cls, node):
442
+ from ..loop_body import LoopBody
443
+ from ..scheduler import SchedulerNode
444
+
445
+ assert isinstance(node, SchedulerNode)
446
+ assert isinstance(node._body, LoopBody)
447
+ DataTypePropagation.propagate_loopbody(node._body)
448
+
449
+
450
+ # This printer contains rules that are supposed to be generic for both C/C++ and
451
+ # Python
452
+ class ExprPrinter(Printer):
453
+ @staticmethod
454
+ def paren(string):
455
+ def all_in_parens(string):
456
+ if string[0] != "(" or len(string) < 2:
457
+ return False
458
+ count = 1
459
+ for i, char in enumerate(string[1:]):
460
+ if char == "(":
461
+ count += 1
462
+ elif char == ")":
463
+ count -= 1
464
+ if count == 0 and i != len(string) - 2:
465
+ return False
466
+ assert count == 0
467
+ return True
468
+
469
+ if (
470
+ isinstance(string, CSEVariable)
471
+ or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
472
+ or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
473
+ or string == ""
474
+ ):
475
+ return string
476
+ # don't put extra parens for strings that are already wrapped in parens
477
+ if all_in_parens(string):
478
+ return string
479
+ return f"({string})"
480
+
481
+ def _print_Relational(self, expr):
482
+ return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))
483
+
484
+ def _print_Mul(self, expr):
485
+ return "*".join(map(self.paren, map(self._print, expr.args)))
486
+
487
+ def _print_Add(self, expr):
488
+ return " + ".join(map(self.paren, map(self._print, expr.args)))
489
+
490
+ # NB: this is OK to put here, because Mod is only defined for positive
491
+ # numbers, and so across C/Python its behavior is consistent
492
+ def _print_Mod(self, expr):
493
+ return " % ".join(map(self.paren, map(self._print, expr.args)))
494
+
495
+ def _print_FloatTrueDiv(self, expr):
496
+ lhs, rhs = expr.args
497
+ return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
498
+
499
+ def _print_CleanDiv(self, expr):
500
+ return self._print_FloorDiv(expr)
501
+
502
+ def _print_Identity(self, expr):
503
+ return self._print(expr.args[0])
504
+
505
+ def _print_GreaterThan(self, expr):
506
+ # GreaterThan: >=
507
+ # StrictlyGreaterThan: >
508
+ # Go figure...
509
+ return " >= ".join(map(self.paren, map(self._print, expr.args)))
510
+
511
+ # NB: The C implementation is injected into codegen at
512
+ # torch/_inductor/codegen/wrapper.py
513
+ def _print_align(self, expr):
514
+ assert len(expr.args) == 1
515
+ return f"align({self._print(expr.args[0])})"
516
+
517
+ # This must be implemented because sympy will collect x * x into Pow(x, 2), without
518
+ # any explicit intervention. We print it just like x * x, notably, we
519
+ # never generate sympy.Pow with floats.
520
+ #
521
+ # NB: this pow by natural, you should never have used builtin sympy.pow
522
+ # for FloatPow, and a symbolic exponent should be PowByNatural. This
523
+ # means exp is guaranteed to be integer.
524
+ def _print_Pow(self, expr):
525
+ base, exp = expr.args
526
+ base = self._print(base)
527
+ assert exp == int(exp), exp
528
+ exp = int(exp)
529
+ assert exp >= 0
530
+ if exp > 0:
531
+ return "*".join([self.paren(base)] * exp)
532
+ else: # exp == 0
533
+ return "1"
534
+
535
+ # Explicit NotImplemented functions are to prevent default sympy printing
536
+ # behavior, which will just barf out ToFloat(...) to your IR. The error
537
+ # message is better here because it tells you which printer class it needs
538
+ # to go in.
539
+
540
+ def _print_ToFloat(self, expr):
541
+ raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")
542
+
543
+ def _print_Infinity(self, expr):
544
+ raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")
545
+
546
+ def _print_NegativeInfinity(self, expr):
547
+ raise NotImplementedError(
548
+ f"_print_NegativeInfinity not implemented for {type(self)}"
549
+ )
550
+
551
+ def _print_FloorDiv(self, expr):
552
+ raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
553
+
554
+ def _print_PythonMod(self, expr):
555
+ raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")
556
+
557
+ def _print_IntTrueDiv(self, expr):
558
+ raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")
559
+
560
+ def _print_PowByNatural(self, expr):
561
+ raise NotImplementedError(
562
+ f"_print_PowByNatural not implemented for {type(self)}"
563
+ )
564
+
565
+ def _print_FloatPow(self, expr):
566
+ raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")
567
+
568
+ def _print_TruncToInt(self, expr):
569
+ raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")
570
+
571
+ def _print_RoundToInt(self, expr):
572
+ raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")
573
+
574
+ def _print_RoundDecimal(self, expr):
575
+ raise NotImplementedError(
576
+ f"_print_RoundDecimal not implemented for {type(self)}"
577
+ )
578
+
579
+ # NB: Some float operations are INTENTIONALLY not implemented for
580
+ # printers. You can implement them as a quick unblock, but it is better
581
+ # to ask yourself why we haven't done this computation in the Tensor
582
+ # universe instead
583
+
584
+ def _print_TruncToFloat(self, expr):
585
+ raise NotImplementedError(
586
+ f"_print_TruncToFloat not implemented for {type(self)}"
587
+ )
588
+
589
+ def doprint(self, expr, *, simplify: bool = True):
590
+ # TODO: why are people passing strings to the printer here :think:
591
+ if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
592
+ expr = V.graph.sizevars.simplify(expr)
593
+ return super().doprint(expr)
594
+
595
+
596
+ class PythonPrinter(ExprPrinter):
597
+ def _print_ToFloat(self, expr):
598
+ assert len(expr.args) == 1
599
+ return f"float({self._print(expr.args[0])})"
600
+
601
+ def _print_ModularIndexing(self, expr):
602
+ x, div, mod = expr.args
603
+ x = self.paren(self.doprint(x))
604
+ div = self.paren(self.doprint(div))
605
+ mod = self.paren(self.doprint(mod))
606
+ if div != "1":
607
+ x = f"({x} // {div})"
608
+ return f"{x} % {mod}"
609
+
610
+ def _print_Infinity(self, expr):
611
+ return "math.inf"
612
+
613
+ def _print_NegativeInfinity(self, expr):
614
+ return "-math.inf"
615
+
616
+ # WARNING: this is dangerous for Triton, which has C-style modulus
617
+ def _print_PythonMod(self, expr):
618
+ return " % ".join(map(self.paren, map(self._print, expr.args)))
619
+
620
+ # WARNING: this is dangerous for Triton, which has C-style modulus
621
+ def _print_FloorDiv(self, expr):
622
+ x, div = expr.args
623
+ x = self.paren(self.doprint(x))
624
+ div = self.paren(self.doprint(div))
625
+ return f"({x} // {div})"
626
+
627
+ # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
628
+ # does a special algorithm
629
+ def _print_IntTrueDiv(self, expr):
630
+ lhs, rhs = expr.args
631
+ return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
632
+
633
+ def _helper_sqrt(self, expr):
634
+ return f"math.sqrt({self._print(expr)})"
635
+
636
+ def _print_OpaqueUnaryFn_sqrt(self, expr):
637
+ return self._helper_sqrt(expr.args[0])
638
+
639
+ def _print_FloatPow(self, expr):
640
+ base, exp = expr.args
641
+ return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"
642
+
643
+ # TODO: Not sure this works with Triton, even when base/exp are integral
644
+ def _print_PowByNatural(self, expr):
645
+ base, exp = expr.args
646
+ return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"
647
+
648
+ def _print_floor(self, expr):
649
+ assert len(expr.args) == 1
650
+ return f"math.floor({self._print(expr.args[0])})"
651
+
652
+ def _print_FloorToInt(self, expr):
653
+ assert len(expr.args) == 1
654
+ return f"math.floor({self._print(expr.args[0])})"
655
+
656
+ def _print_TruncToInt(self, expr):
657
+ assert len(expr.args) == 1
658
+ # This also could have been int(), they'll do the same thing for float
659
+ return f"math.trunc({self._print(expr.args[0])})"
660
+
661
+ def _print_ceiling(self, expr):
662
+ assert len(expr.args) == 1
663
+ return f"math.ceil({self._print(expr.args[0])})"
664
+
665
+ def _print_CeilToInt(self, expr):
666
+ assert len(expr.args) == 1
667
+ return f"math.ceil({self._print(expr.args[0])})"
668
+
669
+ def _print_Abs(self, expr):
670
+ assert len(expr.args) == 1
671
+ return f"abs({self._print(expr.args[0])})"
672
+
673
+ # NB: It's expected that we've made explicit any promotion in the sympy
674
+ # expression, so it doesn't matter that Python max/min doesn't perform
675
+ # promotion
676
+ def _print_Max(self, expr):
677
+ assert len(expr.args) >= 2
678
+ return f"max({', '.join(map(self._print, expr.args))})"
679
+
680
+ def _print_Min(self, expr):
681
+ assert len(expr.args) >= 2
682
+ return f"min({', '.join(map(self._print, expr.args))})"
683
+
684
+ def _print_OpaqueUnaryFn_cos(self, expr):
685
+ assert len(expr.args) == 1
686
+ return f"math.cos({self._print(expr.args[0])})"
687
+
688
+ def _print_OpaqueUnaryFn_cosh(self, expr):
689
+ assert len(expr.args) == 1
690
+ return f"math.cosh({self._print(expr.args[0])})"
691
+
692
+ def _print_OpaqueUnaryFn_acos(self, expr):
693
+ assert len(expr.args) == 1
694
+ return f"math.acos({self._print(expr.args[0])})"
695
+
696
+ def _print_OpaqueUnaryFn_sin(self, expr):
697
+ assert len(expr.args) == 1
698
+ return f"math.sin({self._print(expr.args[0])})"
699
+
700
+ def _print_OpaqueUnaryFn_sinh(self, expr):
701
+ assert len(expr.args) == 1
702
+ return f"math.sinh({self._print(expr.args[0])})"
703
+
704
+ def _print_OpaqueUnaryFn_asin(self, expr):
705
+ assert len(expr.args) == 1
706
+ return f"math.asin({self._print(expr.args[0])})"
707
+
708
+ def _print_OpaqueUnaryFn_tan(self, expr):
709
+ assert len(expr.args) == 1
710
+ return f"math.tan({self._print(expr.args[0])})"
711
+
712
+ def _print_OpaqueUnaryFn_tanh(self, expr):
713
+ assert len(expr.args) == 1
714
+ return f"math.tanh({self._print(expr.args[0])})"
715
+
716
+ def _print_OpaqueUnaryFn_atan(self, expr):
717
+ assert len(expr.args) == 1
718
+ return f"math.atan({self._print(expr.args[0])})"
719
+
720
+ def _print_RoundToInt(self, expr):
721
+ assert len(expr.args) == 1
722
+ return f"round({self._print(expr.args[0])})"
723
+
724
+ def _print_RoundDecimal(self, expr):
725
+ assert len(expr.args) == 2
726
+ number, ndigits = expr.args
727
+ assert isinstance(ndigits, sympy.Integer)
728
+ return f"round({self._print(number)}, {ndigits})"
729
+
730
+
731
+ class OpOverrides:
732
+ def __init__(self, parent):
733
+ super().__init__()
734
+ self._parent = parent
735
+
736
+ def __getattr__(self, item):
737
+ return getattr(self._parent, item)
738
+
739
+ @staticmethod
740
+ def identity(value):
741
+ # used to trigger cse
742
+ return value
743
+
744
+ @staticmethod
745
+ def constant(value, dtype):
746
+ return repr(value)
747
+
748
+ @staticmethod
749
+ def reciprocal(x):
750
+ return ops.truediv(ops.constant(1, torch.int32), x)
751
+
752
+ @staticmethod
753
+ def square(x):
754
+ return ops.mul(x, x)
755
+
756
+ @staticmethod
757
+ def erfc(x):
758
+ return ops.sub(ops.constant(1, torch.float32), ops.erf(x))
759
+
760
+ @staticmethod
761
+ def erfcx(x):
762
+ return ops.mul(ops.exp(ops.square(x)), ops.erfc(x))
763
+
764
+ @staticmethod
765
+ def expm1(x):
766
+ return ops.sub(ops.exp(x), ops.constant(1, torch.float32))
767
+
768
+ @staticmethod
769
+ def log10(x):
770
+ return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32))
771
+
772
+ @staticmethod
773
+ def log2(x):
774
+ return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32))
775
+
776
+ @staticmethod
777
+ def exp2(x):
778
+ return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32)))
779
+
780
+ @staticmethod
781
+ def log1p(x):
782
+ return ops.log(ops.add(x, ops.constant(1, torch.int32)))
783
+
784
+ @staticmethod
785
+ def sigmoid(x):
786
+ one = ops.constant(1, torch.int32)
787
+ return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))
788
+
789
+ @staticmethod
790
+ def libdevice_sigmoid(x):
791
+ one = ops.constant(1, torch.int32)
792
+ return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))
793
+
794
+ @staticmethod
795
+ def relu(x):
796
+ return ops.maximum(x, ops.constant(0, torch.int32))
797
+
798
+ @staticmethod
799
+ def libdevice_abs(x):
800
+ return ops.abs(x)
801
+
802
+ @staticmethod
803
+ def libdevice_sqrt(x):
804
+ return ops.sqrt(x)
805
+
806
+ @staticmethod
807
+ def libdevice_cos(x):
808
+ return ops.cos(x)
809
+
810
+ @staticmethod
811
+ def libdevice_sin(x):
812
+ return ops.sin(x)
813
+
814
+ @staticmethod
815
+ def libdevice_log(x):
816
+ return ops.log(x)
817
+
818
+ @staticmethod
819
+ def libdevice_exp(x):
820
+ return ops.exp(x)
821
+
822
+ @staticmethod
823
+ def bitwise_not(x):
824
+ return f"~{ExprPrinter.paren(x)}"
825
+
826
+ @staticmethod
827
+ def logical_not(a):
828
+ return f"{ExprPrinter.paren(a)} == 0"
829
+
830
+ @staticmethod
831
+ def bitwise_and(x, y):
832
+ return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"
833
+
834
+ @staticmethod
835
+ def bitwise_or(x, y):
836
+ return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"
837
+
838
+ @staticmethod
839
+ def bitwise_xor(x, y):
840
+ return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"
841
+
842
+ @staticmethod
843
+ def bitwise_left_shift(x, y):
844
+ return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"
845
+
846
+ @staticmethod
847
+ def bitwise_right_shift(x, y):
848
+ return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"
849
+
850
+ @staticmethod
851
+ def remainder(a, b):
852
+ r = ops.mod(a, b)
853
+ cond = ops.and_(
854
+ ops.ne(r, ops.constant(0, torch.int32)),
855
+ ops.ne(ops.signbit(r), ops.signbit(b)),
856
+ )
857
+ return ops.where(cond, ops.add(r, b), r)
858
+
859
+ @staticmethod
860
+ def trunc_to_int(a, dtype):
861
+ return ops.to_dtype(ops.trunc(a), dtype)
862
+
863
+ @staticmethod
864
+ def floor_to_int(a, dtype):
865
+ return ops.to_dtype(ops.floor(a), dtype)
866
+
867
+ @staticmethod
868
+ def ceil_to_int(a, dtype):
869
+ return ops.to_dtype(ops.ceil(a), dtype)
870
+
871
+ @staticmethod
872
+ def round_to_int(a, dtype):
873
+ return ops.to_dtype(ops.round(a), dtype)
874
+
875
+ @staticmethod
876
+ def int_truediv(a, b):
877
+ # TODO: this is wrong
878
+ # TODO: an easy bandaid is to generate runtime asserts that it's
879
+ # <= 2**53, which is when this equation is correct
880
+ return ops.truediv(a, b)
881
+
882
+ @staticmethod
883
+ def load_seed(name, offset):
884
+ return ops.load(name, sympy.Integer(offset))
885
+
886
+ @classmethod
887
+ def _initialize_pointwise_overrides(cls, target):
888
+ assert target in {"triton", "cpp", "cppvec"}, target
889
+
890
+ for funcname, data in pointwise_overrides_data.items():
891
+ impl = getattr(data, target)
892
+ if impl is None:
893
+ continue
894
+ setattr(cls, funcname, staticmethod(impl))
895
+
896
+
897
+ @dataclasses.dataclass
898
+ class OverridesData:
899
+ name: str
900
+ cpp: Callable[..., str]
901
+ # None when not impl in libdevice/triton
902
+ triton: Optional[Callable[..., str]] = None
903
+ # None when not impl in aten/.../vec
904
+ cppvec: Optional[Callable[..., str]] = None
905
+ type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
906
+ ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
907
+ )
908
+
909
+
910
+ # NB: if you add a new special function, don't forget to update
911
+ # torch._inductor.ops_handler too
912
+ pointwise_overrides_data: Dict[str, OverridesData] = dict(
913
+ airy_ai=OverridesData(
914
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
915
+ cpp=lambda x: f"airy_ai_forward({x})",
916
+ name="special_airy_ai",
917
+ ),
918
+ bessel_j0=OverridesData(
919
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
920
+ cpp=lambda x: f"bessel_j0_forward({x})",
921
+ triton=lambda x: f"libdevice.j0({x})",
922
+ name="special_bessel_j0",
923
+ ),
924
+ bessel_j1=OverridesData(
925
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
926
+ cpp=lambda x: f"bessel_j1_forward({x})",
927
+ triton=lambda x: f"libdevice.j1({x})",
928
+ name="special_bessel_j1",
929
+ ),
930
+ bessel_y0=OverridesData(
931
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
932
+ cpp=lambda x: f"bessel_y0_forward({x})",
933
+ triton=lambda x: f"libdevice.y0({x})",
934
+ name="special_bessel_y0",
935
+ ),
936
+ bessel_y1=OverridesData(
937
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
938
+ cpp=lambda x: f"bessel_y1_forward({x})",
939
+ triton=lambda x: f"libdevice.y1({x})",
940
+ name="special_bessel_y1",
941
+ ),
942
+ digamma=OverridesData(
943
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
944
+ cpp=lambda x: f"calc_digamma({x})",
945
+ cppvec=lambda x: f"{x}.digamma()",
946
+ name="digamma",
947
+ ),
948
+ # no cpp nor triton implementation for entr, it is defined as decomposition
949
+ # erf, erfc
950
+ erfcx=OverridesData(
951
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
952
+ cpp=lambda x: f"calc_erfcx({x})",
953
+ triton=lambda x: f"libdevice.erfcx({x})",
954
+ name="special_erfcx",
955
+ ),
956
+ fma=OverridesData(
957
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
958
+ cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
959
+ cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
960
+ triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
961
+ name="fma",
962
+ ),
963
+ # erfinv, exp2, expit, gammaln
964
+ igamma=OverridesData(
965
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
966
+ cpp=lambda x, y: f"calc_igamma({x}, {y})",
967
+ name="igamma",
968
+ ),
969
+ igammac=OverridesData(
970
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
971
+ cpp=lambda x, y: f"calc_igammac({x}, {y})",
972
+ name="igammac",
973
+ ),
974
+ gammainc=OverridesData(
975
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
976
+ cpp=lambda x, y: f"calc_igamma({x}, {y})",
977
+ name="special_gammainc",
978
+ ),
979
+ gammaincc=OverridesData(
980
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
981
+ cpp=lambda x, y: f"calc_igammac({x}, {y})",
982
+ name="special_gammaincc",
983
+ ),
984
+ i0=OverridesData(
985
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
986
+ cpp=lambda x: f"calc_i0({x})",
987
+ triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
988
+ cppvec=lambda x: f"{x}.i0()",
989
+ name="i0",
990
+ ),
991
+ i0e=OverridesData(
992
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
993
+ cpp=lambda x: f"calc_i0e({x})",
994
+ cppvec=lambda x: f"{x}.i0e()",
995
+ name="special_i0e",
996
+ ),
997
+ i1=OverridesData(
998
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
999
+ cpp=lambda x: f"calc_i1({x})",
1000
+ triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
1001
+ name="special_i1",
1002
+ ),
1003
+ i1e=OverridesData(
1004
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1005
+ cpp=lambda x: f"calc_i1e({x})",
1006
+ name="special_i1e",
1007
+ ),
1008
+ log_ndtr=OverridesData(
1009
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1010
+ cpp=lambda x: f"calc_log_ndtr({x})",
1011
+ name="special_log_ndtr",
1012
+ ),
1013
+ # logit
1014
+ modified_bessel_i0=OverridesData(
1015
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1016
+ cpp=lambda x: f"modified_bessel_i0_forward({x})",
1017
+ triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
1018
+ name="special_modified_bessel_i0",
1019
+ ),
1020
+ modified_bessel_i1=OverridesData(
1021
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1022
+ cpp=lambda x: f"modified_bessel_i1_forward({x})",
1023
+ triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
1024
+ name="special_modified_bessel_i1",
1025
+ ),
1026
+ modified_bessel_k0=OverridesData(
1027
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1028
+ cpp=lambda x: f"modified_bessel_k0_forward({x})",
1029
+ name="special_modified_bessel_k0",
1030
+ ),
1031
+ modified_bessel_k1=OverridesData(
1032
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1033
+ cpp=lambda x: f"modified_bessel_k1_forward({x})",
1034
+ name="special_modified_bessel_k1",
1035
+ ),
1036
+ # multigamma
1037
+ ndtr=OverridesData(
1038
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1039
+ cpp=lambda x: f"calc_ndtr({x})",
1040
+ name="special_ndtr",
1041
+ ),
1042
+ ndtri=OverridesData(
1043
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1044
+ cpp=lambda x: f"calc_ndtri({x})",
1045
+ name="special_ndtri",
1046
+ ),
1047
+ polygamma=OverridesData(
1048
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1049
+ cpp=lambda x, y: f"calc_polygamma({y}, {x})",
1050
+ name="polygamma",
1051
+ ),
1052
+ # psi - alias to digamma
1053
+ # round
1054
+ scaled_modified_bessel_k0=OverridesData(
1055
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1056
+ cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
1057
+ name="special_scaled_modified_bessel_k0",
1058
+ ),
1059
+ scaled_modified_bessel_k1=OverridesData(
1060
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1061
+ cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
1062
+ name="special_scaled_modified_bessel_k1",
1063
+ ),
1064
+ # sinc
1065
+ spherical_bessel_j0=OverridesData(
1066
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1067
+ cpp=lambda x: f"spherical_bessel_j0_forward({x})",
1068
+ name="special_spherical_bessel_j0",
1069
+ ),
1070
+ zeta=OverridesData(
1071
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1072
+ cpp=lambda x, y: f"zeta({x}, {y})",
1073
+ name="special_zeta",
1074
+ ),
1075
+ chebyshev_polynomial_t=OverridesData(
1076
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1077
+ cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
1078
+ name="special_chebyshev_polynomial_t",
1079
+ ),
1080
+ chebyshev_polynomial_u=OverridesData(
1081
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1082
+ cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
1083
+ name="special_chebyshev_polynomial_u",
1084
+ ),
1085
+ chebyshev_polynomial_v=OverridesData(
1086
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1087
+ cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
1088
+ name="special_chebyshev_polynomial_v",
1089
+ ),
1090
+ chebyshev_polynomial_w=OverridesData(
1091
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1092
+ cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
1093
+ name="special_chebyshev_polynomial_w",
1094
+ ),
1095
+ legendre_polynomial_p=OverridesData(
1096
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1097
+ cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
1098
+ name="special_legendre_polynomial_p",
1099
+ ),
1100
+ shifted_chebyshev_polynomial_t=OverridesData(
1101
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1102
+ cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
1103
+ name="special_shifted_chebyshev_polynomial_t",
1104
+ ),
1105
+ shifted_chebyshev_polynomial_u=OverridesData(
1106
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1107
+ cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
1108
+ name="special_shifted_chebyshev_polynomial_u",
1109
+ ),
1110
+ shifted_chebyshev_polynomial_v=OverridesData(
1111
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1112
+ cpp=lambda x, y: f"shifted_chebyshev_polynomial_v_forward({x}, {y})",
1113
+ name="special_shifted_chebyshev_polynomial_v",
1114
+ ),
1115
+ shifted_chebyshev_polynomial_w=OverridesData(
1116
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1117
+ cpp=lambda x, y: f"shifted_chebyshev_polynomial_w_forward({x}, {y})",
1118
+ name="special_shifted_chebyshev_polynomial_w",
1119
+ ),
1120
+ hermite_polynomial_h=OverridesData(
1121
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1122
+ cpp=lambda x, y: f"hermite_polynomial_h_forward({x}, {y})",
1123
+ name="special_hermite_polynomial_h",
1124
+ ),
1125
+ hermite_polynomial_he=OverridesData(
1126
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1127
+ cpp=lambda x, y: f"hermite_polynomial_he_forward({x}, {y})",
1128
+ name="special_hermite_polynomial_he",
1129
+ ),
1130
+ laguerre_polynomial_l=OverridesData(
1131
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1132
+ cpp=lambda x, y: f"laguerre_polynomial_l_forward({x}, {y})",
1133
+ name="special_laguerre_polynomial_l",
1134
+ ),
1135
+ )
1136
+
1137
+
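A rough sketch of how the table above is consumed: calling _initialize_pointwise_overrides on a backend's OpOverrides subclass installs one staticmethod per entry that has an implementation for that backend. The class below is hypothetical, for illustration only.

class _HypotheticalTritonOverrides(OpOverrides):  # illustrative only, not a real backend
    pass

_HypotheticalTritonOverrides._initialize_pointwise_overrides("triton")
# entries with a `triton` impl become staticmethods on the class, e.g.:
assert _HypotheticalTritonOverrides.bessel_j0("tmp0") == "libdevice.j0(tmp0)"
# entries without one (e.g. airy_ai, which only has a cpp impl) are skipped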
1138
+ # Use mypy to check protocol implemented correctly
1139
+ def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
1140
+ return h
1141
+
1142
+
1143
+ class DeferredLine(DeferredLineBase):
1144
+ """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""
1145
+
1146
+ def __init__(self, name, line):
1147
+ super().__init__(line)
1148
+ self.name = name
1149
+ assert not isinstance(line, DeferredLineBase)
1150
+
1151
+ def __call__(self):
1152
+ if all(
1153
+ self.name not in x
1154
+ for x in (
1155
+ V.graph.removed_buffers,
1156
+ V.kernel.removed_buffers,
1157
+ V.graph.inplaced_to_remove,
1158
+ V.kernel.inplaced_to_remove,
1159
+ )
1160
+ ):
1161
+ return self.line
1162
+ return None
1163
+
1164
+ def _new_line(self, line):
1165
+ return DeferredLine(self.name, line)
1166
+
1167
+
1168
+ class BracesBuffer(IndentedBuffer):
1169
+ def indent(self, offset=1):
1170
+ @contextlib.contextmanager
1171
+ def ctx():
1172
+ for _ in range(offset):
1173
+ self.writeline("{")
1174
+ self._indent += 1
1175
+ for _ in range(-offset):
1176
+ self._indent -= 1
1177
+ self.writeline("}")
1178
+ yield
1179
+ for _ in range(-offset):
1180
+ self.writeline("{")
1181
+ self._indent += 1
1182
+ for _ in range(offset):
1183
+ self._indent -= 1
1184
+ self.writeline("}")
1185
+
1186
+ return ctx()
1187
+
1188
+
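A minimal sketch of what BracesBuffer.indent emits (the exact indentation width comes from IndentedBuffer):

buf = BracesBuffer()
buf.writeline("for (int i = 0; i < n; ++i)")
with buf.indent():
    buf.writeline("sum += i;")
# buf.getvalue() now reads roughly:
#   for (int i = 0; i < n; ++i)
#   {
#       sum += i;
#   }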
1189
+ class InplacedBuffer(NamedTuple):
1190
+ inner_name: str
1191
+ other_names: List[str]
1192
+
1193
+
1194
+ class KernelArgs:
1195
+ @staticmethod
1196
+ def _lookup(prefix, odict, name):
1197
+ assert isinstance(name, (str, sympy.Symbol))
1198
+ if name not in odict:
1199
+ odict[name] = f"{prefix}{len(odict)}"
1200
+ return odict[name]
1201
+
1202
+ def __init__(self, sizevars=None):
1203
+ self.input_buffers = {}
1204
+ self.output_buffers = {}
1205
+ self.inplace_buffers = {}
1206
+ self.sizevars = sizevars or {}
1207
+ self.workspace_arg = None
1208
+
1209
+ def __repr__(self):
1210
+ return "KernelArgs({})".format(
1211
+ ", ".join(
1212
+ map(
1213
+ repr,
1214
+ [
1215
+ self.input_buffers,
1216
+ self.output_buffers,
1217
+ self.inplace_buffers,
1218
+ self.sizevars,
1219
+ ],
1220
+ )
1221
+ )
1222
+ )
1223
+
1224
+ def _buffer_is_marked_removed(self, name):
1225
+ return isinstance(name, str) and name.startswith("REMOVED")
1226
+
1227
+ def input(self, name):
1228
+ if V.graph.scheduler:
1229
+ name = V.graph.scheduler.mutation_real_name.get(name, name)
1230
+ assert name not in V.graph.removed_buffers, name
1231
+ if name in self.output_buffers:
1232
+ return self.output_buffers[name]
1233
+ if name in self.inplace_buffers:
1234
+ return self.inplace_buffers[name].inner_name
1235
+ if name.startswith("seed"):
1236
+ return self._lookup("seed", self.input_buffers, name)
1237
+ return self._lookup("in_ptr", self.input_buffers, name)
1238
+
1239
+ def output(self, name):
1240
+ if V.graph.scheduler:
1241
+ name = V.graph.scheduler.mutation_real_name.get(name, name)
1242
+ assert name not in V.graph.removed_buffers, name
1243
+ if name in self.inplace_buffers:
1244
+ return self.inplace_buffers[name].inner_name
1245
+ return self._lookup("out_ptr", self.output_buffers, name)
1246
+
1247
+ def make_inplace(self, input_name, output_name):
1248
+ assert output_name not in self.inplace_buffers
1249
+ if input_name in self.inplace_buffers:
1250
+ buf = self.inplace_buffers[input_name]
1251
+ buf.other_names.append(output_name)
1252
+ self.inplace_buffers[output_name] = buf
1253
+ else:
1254
+ buf = InplacedBuffer(
1255
+ f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
1256
+ [input_name, output_name],
1257
+ )
1258
+ self.inplace_buffers[input_name] = buf
1259
+ self.inplace_buffers[output_name] = buf
1260
+
1261
+ def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
1262
+ if self.workspace_arg is None:
1263
+ self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
1264
+ return "ws_ptr", 0
1265
+
1266
+ offset = self.workspace_arg.nbytes
1267
+ zero_fill = zero_fill or self.workspace_arg.zero_fill
1268
+ self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill)
1269
+ return "ws_ptr", offset
1270
+
1271
+ def seed_offset(self, name, value):
1272
+ if value in self.sizevars:
1273
+ return self.sizevars[value]
1274
+ if name in self.sizevars.values():
1275
+ name = (
1276
+ f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
1277
+ )
1278
+ self.sizevars[value] = name
1279
+ return name
1280
+
1281
+ def size(self, name):
1282
+ if str(name) == "seed":
1283
+ self.sizevars["seed"] = "seed"
1284
+ return "seed"
1285
+ return self._lookup("ks", self.sizevars, name)
1286
+
1287
+ def call_names(self):
1288
+ return chain(
1289
+ self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
1290
+ )
1291
+
1292
+ def wrap_ptr_arg(self, buf, dtype):
1293
+ return buf
1294
+
1295
+ def wrap_size_arg(self, size):
1296
+ return str(size)
1297
+
1298
+ def cpp_argdefs(self):
1299
+ from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE
1300
+
1301
+ call_args = []
1302
+ arg_defs = []
1303
+ arg_types = []
1304
+ for inplaced in unique(self.inplace_buffers.values()):
1305
+ if self._buffer_is_marked_removed(inplaced):
1306
+ continue
1307
+ outer = inplaced.other_names[-1]
1308
+ inner = inplaced.inner_name
1309
+ dtype = V.graph.get_dtype(outer)
1310
+ cpp_dtype = DTYPE_TO_CPP[dtype]
1311
+ arg_defs.append(f"{cpp_dtype}* {inner}")
1312
+ call_args.append(self.wrap_ptr_arg(outer, dtype))
1313
+ arg_types.append(f"{cpp_dtype}*")
1314
+ for outer, inner in self.input_buffers.items():
1315
+ if outer in self.inplace_buffers:
1316
+ continue
1317
+ dtype = V.graph.get_dtype(outer)
1318
+ cpp_dtype = DTYPE_TO_CPP[dtype]
1319
+ arg_defs.append(f"const {cpp_dtype}* {inner}")
1320
+ call_args.append(self.wrap_ptr_arg(outer, dtype))
1321
+ arg_types.append(f"const {cpp_dtype}*")
1322
+ for outer, inner in self.output_buffers.items():
1323
+ if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
1324
+ continue
1325
+ dtype = V.graph.get_dtype(outer)
1326
+ cpp_dtype = DTYPE_TO_CPP[dtype]
1327
+ arg_defs.append(f"{cpp_dtype}* {inner}")
1328
+ call_args.append(self.wrap_ptr_arg(outer, dtype))
1329
+ arg_types.append(f"{cpp_dtype}*")
1330
+ for outer, inner in self.sizevars.items():
1331
+ arg_defs.append(f"const {INDEX_TYPE} {inner}")
1332
+ call_args.append(self.wrap_size_arg(outer))
1333
+ arg_types.append(f"const {INDEX_TYPE}")
1334
+ if V.graph.wrapper_code:
1335
+ V.graph.wrapper_code.ensure_size_computed(outer)
1336
+ assert self.workspace_arg is None, "Workspace not supported on CPU "
1337
+ return arg_defs, call_args, arg_types
1338
+
1339
+ def python_argdefs(self):
1340
+ arg_defs: List[str] = []
1341
+ call_args: List[str] = []
1342
+ arg_types: List[torch.dtype] = []
1343
+ precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []
1344
+ for inplaced in unique(self.inplace_buffers.values()):
1345
+ if self._buffer_is_marked_removed(inplaced):
1346
+ continue
1347
+ arg_defs.append(inplaced.inner_name)
1348
+ call_args.append(inplaced.other_names[-1])
1349
+ arg_types.append(V.graph.get_dtype(inplaced.other_names[-1]))
1350
+ precompile_args.append(
1351
+ TensorArg(
1352
+ name=inplaced.inner_name,
1353
+ buffer=inplaced.other_names[-1],
1354
+ dtype=V.graph.get_dtype(inplaced.other_names[-1]),
1355
+ )
1356
+ )
1357
+ for outer, inner in chain(
1358
+ self.input_buffers.items(), self.output_buffers.items()
1359
+ ):
1360
+ if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
1361
+ continue
1362
+ arg_defs.append(inner)
1363
+ call_args.append(outer)
1364
+ arg_types.append(V.graph.get_dtype(outer))
1365
+ precompile_args.append(
1366
+ TensorArg(
1367
+ name=inner,
1368
+ buffer=outer,
1369
+ dtype=V.graph.get_dtype(outer),
1370
+ )
1371
+ )
1372
+ for outer, inner in self.sizevars.items():
1373
+ arg_defs.append(inner)
1374
+ call_args.append(outer)
1375
+ arg_types.append(type(outer)) # type: ignore[arg-type]
1376
+ precompile_args.append(SizeArg(inner, outer))
1377
+ if V.graph.wrapper_code:
1378
+ V.graph.wrapper_code.ensure_size_computed(outer)
1379
+ if self.workspace_arg is not None:
1380
+ arg_defs.append("ws_ptr")
1381
+ call_args.append("workspace")
1382
+ precompile_args.append(self.workspace_arg)
1383
+ return arg_defs, call_args, precompile_args, arg_types
1384
+
1385
+ def aliases(self):
1386
+ for inplaced in unique(self.inplace_buffers.values()):
1387
+ if self._buffer_is_marked_removed(inplaced):
1388
+ continue
1389
+ for other in inplaced.other_names:
1390
+ if (
1391
+ other in V.graph.inplaced_to_remove
1392
+ or other in V.kernel.inplaced_to_remove
1393
+ ):
1394
+ continue
1395
+ if other in self.input_buffers:
1396
+ yield self.input_buffers[other], inplaced.inner_name
1397
+ if other in self.output_buffers:
1398
+ yield self.output_buffers[other], inplaced.inner_name
1399
+
1400
+ def is_removed(self, name):
1401
+ def _is_removed(name, buffers):
1402
+ return name not in buffers or self._buffer_is_marked_removed(buffers[name])
1403
+
1404
+ return _is_removed(name, self.output_buffers) and _is_removed(
1405
+ name, self.inplace_buffers
1406
+ )
1407
+
1408
+ # Includes inplace buffers, excludes removed buffers. Essentially,
1409
+ # after you do a call into this kernel, which buffers actually contain
1410
+ # updated data? Modeled off of python_argdefs.
1411
+ def live_output_buffers(self):
1412
+ live_outs = OrderedSet() # type: ignore[var-annotated]
1413
+ for inplaced in unique(self.inplace_buffers.values()):
1414
+ if self._buffer_is_marked_removed(inplaced):
1415
+ continue
1416
+ live_outs.add(inplaced.other_names[-1])
1417
+ for outer, inner in self.output_buffers.items():
1418
+ if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
1419
+ continue
1420
+ live_outs.add(outer)
1421
+ return live_outs
1422
+
1423
+
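A small sketch of the argument-naming scheme KernelArgs implements; none of these helpers need a live graph context:

args = KernelArgs()
names = {}
assert KernelArgs._lookup("in_ptr", names, "buf0") == "in_ptr0"
assert KernelArgs._lookup("in_ptr", names, "buf1") == "in_ptr1"
assert KernelArgs._lookup("in_ptr", names, "buf0") == "in_ptr0"  # stable on reuse
args.make_inplace("buf2", "buf3")
assert args.inplace_buffers["buf2"].inner_name == "in_out_ptr0"
assert args.inplace_buffers["buf3"] is args.inplace_buffers["buf2"]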
1424
+ class CSEVariable:
1425
+ """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
1426
+ To do so, the backends can simply overload `Kernel.create_cse_var`
1427
+ The "CSEVariable.update_on_args" method gives you a hook for annotations
1428
+ See example of TritonCSEVariable in triton.py
1429
+ """
1430
+
1431
+ def __init__(self, name, bounds: ValueRanges[Any]):
1432
+ assert isinstance(bounds, ValueRanges)
1433
+ self.name = name
1434
+ self.bounds = bounds
1435
+ self.use_count = 1 # track how many times this expression is used
1436
+
1437
+ def __str__(self):
1438
+ return self.name
1439
+
1440
+ def __hash__(self) -> int:
1441
+ return hash(self.name)
1442
+
1443
+ def __eq__(self, other) -> bool:
1444
+ return type(other) == type(self) and other.name == self.name
1445
+
1446
+ def update_on_args(self, name, args, kwargs):
1447
+ pass
1448
+
1449
+ def __repr__(self):
1450
+ return f"{self.__class__.__name__}({self.name!r})"
1451
+
1452
+
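As the docstring above suggests, a backend annotates values by subclassing CSEVariable and returning it from its kernel's create_cse_var; a hypothetical sketch (see TritonCSEVariable in triton.py for the real one):

class _HypotheticalCSEVariable(CSEVariable):  # illustrative only
    def __init__(self, name, bounds):
        super().__init__(name, bounds)
        self.seen_ops = []

    def update_on_args(self, name, args, kwargs):
        # record every op this value flowed through
        self.seen_ops.append(name)

# A backend kernel would then override create_cse_var (defined on Kernel below)
# to return _HypotheticalCSEVariable instead of the base class.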
1453
+ class CppWrapperKernelArgs(KernelArgs):
1454
+ def wrap_ptr_arg(self, buf, dtype):
1455
+ from .cpp_utils import DTYPE_TO_CPP
1456
+
1457
+ if config.abi_compatible:
1458
+ # In the abi_compatible model, we just return the buf here.
1459
+ # We will form correct call args later in wrapper.generate_kernel_all.
1460
+ return buf
1461
+ else:
1462
+ return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"
1463
+
1464
+ def wrap_size_arg(self, size):
1465
+ return f"{size}"
1466
+
1467
+
1468
+ class CSE:
1469
+ """Common subexpression elimination"""
1470
+
1471
+ def __init__(
1472
+ self,
1473
+ prefix="",
1474
+ suffix="",
1475
+ name_prefix="tmp",
1476
+ iter_buffers=None,
1477
+ store_cache=None,
1478
+ reduction_cache=None,
1479
+ varname_map=None,
1480
+ ):
1481
+ self.prefix = prefix
1482
+ self.suffix = suffix
1483
+ self.cache = {}
1484
+ self.name_prefix = name_prefix
1485
+ self.store_cache = store_cache or {}
1486
+ self.reduction_cache = reduction_cache or {}
1487
+ self.iter_buffer_ids = iter_buffers or itertools.count()
1488
+ self.invalidated_stores = OrderedSet() # type: ignore[var-annotated]
1489
+ self.varname_map = varname_map or {}
1490
+
1491
+ def invalidate(self, keep_vars: OrderedSet[str]):
1492
+ for name, tmp in list(self.store_cache.items()):
1493
+ if tmp not in keep_vars:
1494
+ del self.store_cache[name]
1495
+ self.invalidated_stores.add(name)
1496
+ self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}
1497
+
1498
+ def clone(self):
1499
+ # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
1500
+ return CSE(
1501
+ prefix=self.prefix,
1502
+ suffix=self.suffix,
1503
+ name_prefix=self.name_prefix,
1504
+ iter_buffers=self.iter_buffer_ids,
1505
+ store_cache=self.store_cache,
1506
+ varname_map=self.varname_map,
1507
+ )
1508
+
1509
+ def generate(
1510
+ self,
1511
+ buffer: IndentedBuffer,
1512
+ expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
1513
+ *,
1514
+ bounds: ValueRanges[Any] = ValueRanges.unknown(),
1515
+ write=True,
1516
+ assignment=True,
1517
+ ) -> CSEVariable:
1518
+ if isinstance(expr, OpsValue):
1519
+ expr = expr.value
1520
+
1521
+ assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
1522
+ assert write or assignment
1523
+ if isinstance(expr, CSEVariable):
1524
+ # If the expressions were always created with all the information, we could
1525
+ # assert expr.bounds == bounds, but sometimes the expression is created
1526
+ # with the loose ValueRanges.unknown(), so we need to tighten the bounds
1527
+ expr.bounds = expr.bounds.tighten(bounds)
1528
+ expr.use_count += 1
1529
+ return expr
1530
+ cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
1531
+ var = self.cache.get(cache_key, None)
1532
+ if not var:
1533
+ var = self.newvar(bounds)
1534
+ self.cache[cache_key] = var
1535
+ if write:
1536
+ if V.kernel.current_node:
1537
+ V.kernel.current_node.codegen_originating_info(
1538
+ buffer, only_once=True
1539
+ )
1540
+ if isinstance(expr, IndentedBuffer):
1541
+ if assignment:
1542
+ buffer.writeline(f"{self.prefix}{var} =")
1543
+ buffer.splice(expr)
1544
+ buffer.writeline(self.suffix)
1545
+ else:
1546
+ if assignment:
1547
+ line = f"{self.prefix}{var} = {expr}{self.suffix}"
1548
+ else:
1549
+ line = f"{expr}{self.suffix}"
1550
+ buffer.writeline(line)
1551
+ else:
1552
+ var.bounds = var.bounds.tighten(bounds)
1553
+ var.use_count += 1
1554
+
1555
+ return var
1556
+
1557
+ def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
1558
+ var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
1559
+ var = V.kernel.create_cse_var(var_name, bounds)
1560
+ self.varname_map[var_name] = var
1561
+ return var
1562
+
1563
+
1564
+ class CodeGen:
1565
+ def __init__(self) -> None:
1566
+ super().__init__()
1567
+ self.exit_stack = contextlib.ExitStack()
1568
+
1569
+ def __enter__(self):
1570
+ self.exit_stack.__enter__()
1571
+ return self
1572
+
1573
+ def __exit__(self, exc_type, exc_val, exc_tb):
1574
+ self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
1575
+
1576
+
1577
+ class ScopedDict:
1578
+ def __init__(self, original_dict):
1579
+ self.original_dict = original_dict
1580
+ self.new_items = {}
1581
+
1582
+ def __getitem__(self, key):
1583
+ if key in self.new_items:
1584
+ return self.new_items[key]
1585
+ return self.original_dict[key]
1586
+
1587
+ def __setitem__(self, key, value):
1588
+ self.new_items[key] = value
1589
+
1590
+ def __contains__(self, key):
1591
+ return key in self.new_items or key in self.original_dict
1592
+
1593
+ def get(self, key, default=None):
1594
+ if key in self.new_items:
1595
+ return self.new_items[key]
1596
+ return self.original_dict.get(key, default)
1597
+
1598
+
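A quick sketch of the layering ScopedDict provides; swap_buffers below relies on this to scope the CSE caches without mutating the originals:

base = {"a": 1}
scoped = ScopedDict(base)
scoped["b"] = 2
assert scoped["a"] == 1 and scoped["b"] == 2
assert "b" in scoped and scoped.get("missing", 0) == 0
assert "b" not in base  # writes stay in the scope; the original dict is untouched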
1599
+ class Kernel(CodeGen):
1600
+ newvar_prefix = ""
1601
+ suffix = ""
1602
+ overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
1603
+ # TODO: these look dead, but with all the getattr it's hard to tell...
1604
+ load_format: None = None
1605
+ store_format: None = None
1606
+
1607
+ def __init__(self, args=None, increase_kernel_count=True):
1608
+ super().__init__()
1609
+ if increase_kernel_count:
1610
+ metrics.generated_kernel_count += 1
1611
+ self.args = args or KernelArgs()
1612
+ self.loads = IndentedBuffer()
1613
+ self.compute = IndentedBuffer()
1614
+ self.stores = IndentedBuffer()
1615
+
1616
+ self.num_load = 0
1617
+ self.num_reduction = 0
1618
+
1619
+ self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
1620
+ self.must_keep_buffers = OrderedSet() # type: ignore[var-annotated]
1621
+ self.store_buffer_names = OrderedSet() # type: ignore[var-annotated]
1622
+ self._load_mask = None
1623
+ self._load_other = None
1624
+ # OrderedSet in set_current_node
1625
+ self.current_node = None
1626
+ self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
1627
+
1628
+ self.removed_buffers = OrderedSet() # type: ignore[var-annotated]
1629
+ self.inplaced_to_remove = OrderedSet() # type: ignore[var-annotated]
1630
+
1631
+ # key: the buffer to write
1632
+ # value: the buffer to read and whose memory can be reused for
1633
+ # the buffer specified by key
1634
+ self.inplace_update_buffers = {}
1635
+ # Set minimum number of elements processed per thread.
1636
+ self.min_elem_per_thread = 1
1637
+ self.kernel_name = None
1638
+
1639
+ @contextlib.contextmanager
1640
+ def set_current_node(self, node):
1641
+ prior = self.current_node
1642
+ self.current_node = node
1643
+ self.node_to_bounds = node._body.bounds().get_bounds()
1644
+ try:
1645
+ yield
1646
+ finally:
1647
+ self.current_node = prior
1648
+
1649
+ @contextlib.contextmanager
1650
+ def swap_buffers(self, lb, cb=None, sb=None):
1651
+ def scope_cse(cse):
1652
+ new_cse = cse.clone()
1653
+ new_cse.cache = ScopedDict(cse.cache)
1654
+ new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
1655
+ new_cse.store_cache = ScopedDict(cse.store_cache)
1656
+ return new_cse
1657
+
1658
+ if cb is None:
1659
+ cb = lb
1660
+ loads = self.loads
1661
+ compute = self.compute
1662
+ stores = self.stores
1663
+ cse = self.cse
1664
+ self.loads = lb
1665
+ self.compute = cb
1666
+ self.stores = sb
1667
+ self.cse = scope_cse(cse)
1668
+ try:
1669
+ yield
1670
+ finally:
1671
+ self.loads = loads
1672
+ self.compute = compute
1673
+ self.stores = stores
1674
+ self.cse = cse
1675
+
1676
+ def load(self, name: str, index: sympy.Expr) -> CSEVariable:
1677
+ raise NotImplementedError
1678
+
1679
+ def indirect_load(self, name: str, index: sympy.Expr):
1680
+ """A load the depends on an index we have read"""
1681
+ prior = self.loads
1682
+ try:
1683
+ # put the load in the compute section as it might have deps
1684
+ self.loads = self.compute
1685
+ return self.load(name, index)
1686
+ finally:
1687
+ self.loads = prior
1688
+
1689
+ def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
1690
+ raise NotImplementedError
1691
+
1692
+ def store(
1693
+ self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
1694
+ ) -> None:
1695
+ raise NotImplementedError
1696
+
1697
+ def reduction(
1698
+ self,
1699
+ dtype: torch.dtype,
1700
+ src_dtype: torch.dtype,
1701
+ reduction_type: ReductionType,
1702
+ value: Union[CSEVariable, Tuple[CSEVariable, ...]],
1703
+ ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
1704
+ raise NotImplementedError
1705
+
1706
+ def scan(
1707
+ self,
1708
+ dtypes: Tuple[torch.dtype, ...],
1709
+ combine_fn: Callable[
1710
+ [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
1711
+ ],
1712
+ values: Tuple[CSEVariable, ...],
1713
+ ) -> Tuple[CSEVariable, ...]:
1714
+ raise NotImplementedError
1715
+
1716
+ def sort(
1717
+ self,
1718
+ dtypes: Tuple[torch.dtype, ...],
1719
+ values: Tuple[CSEVariable, ...],
1720
+ stable: bool,
1721
+ descending: bool,
1722
+ ) -> Tuple[CSEVariable, ...]:
1723
+ raise NotImplementedError
1724
+
1725
+ def var_ranges(self):
1726
+ raise NotImplementedError
1727
+
1728
+ def bucketize(
1729
+ self,
1730
+ values: CSEVariable,
1731
+ offsets_name: str,
1732
+ offsets_size: sympy.Expr,
1733
+ indexing_dtype: torch.dtype,
1734
+ right: bool,
1735
+ ) -> CSEVariable:
1736
+ """
1737
+ See [Note: Inductor bucketize op]
1738
+ """
1739
+ raise NotImplementedError
1740
+
1741
+ @property
1742
+ def assert_function(self) -> str:
1743
+ raise NotImplementedError
1744
+
1745
+ def indirect_assert(
1746
+ self,
1747
+ var: Union[CSEVariable, str],
1748
+ lower: Optional[str],
1749
+ upper: Optional[str],
1750
+ mask: Optional[Union[CSEVariable, str]] = None,
1751
+ ) -> str:
1752
+ if isinstance(var, CSEVariable):
1753
+ var = str(var)
1754
+ assert isinstance(var, str)
1755
+ assert lower is None or isinstance(lower, str)
1756
+ assert upper is None or isinstance(upper, str)
1757
+ if lower and upper:
1758
+ # The conditions need to be in parens because of Python's operator precedence.
1759
+ # It'd be less error-prone to use and/or/not, which is supported by triton
1760
+ cond = f"({lower} <= {var}) & ({var} < {upper})"
1761
+ cond_print = f"{lower} <= {var} < {upper}"
1762
+ elif lower:
1763
+ cond = f"{lower} <= {var}"
1764
+ cond_print = cond
1765
+ else:
1766
+ assert upper
1767
+ cond = f"{var} < {upper}"
1768
+ cond_print = cond
1769
+
1770
+ if mask:
1771
+ cond = f"({cond}) | ~({mask})"
1772
+
1773
+ return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'
1774
+
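# Illustration (assuming a backend whose assert_function is "tl.device_assert",
# as the Triton backend uses): indirect_assert("tmp3", "0", "ks0", mask="xmask")
# returns the string
#   tl.device_assert(((0 <= tmp3) & (tmp3 < ks0)) | ~(xmask), "index out of bounds: 0 <= tmp3 < ks0")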
1775
+ def check_bounds(
1776
+ self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
1777
+ ):
1778
+ raise NotImplementedError
1779
+
1780
+ def index_to_str(self, index: sympy.Expr) -> str:
1781
+ raise NotImplementedError
1782
+
1783
+ def __enter__(self):
1784
+ # TODO: hoist this to top level
1785
+ class CSEProxy:
1786
+ self.name = "CSEProxy"
1787
+ vr_analysis = ValueRangeAnalysis()
1788
+
1789
+ @staticmethod
1790
+ def __getattr__(name: str) -> Callable[..., CSEVariable]: # type: ignore[misc]
1791
+ def inner(*args, **kwargs):
1792
+ bounds = CSEProxy._bound_variable(name, *args, **kwargs)
1793
+
1794
+ value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
1795
+
1796
+ def do_cse(v):
1797
+ csevar = V.kernel.cse.generate(
1798
+ V.kernel.compute, v, bounds=bounds
1799
+ )
1800
+ csevar.update_on_args(name, args, kwargs)
1801
+ return csevar
1802
+
1803
+ return pytree.tree_map(do_cse, value)
1804
+
1805
+ return inner
1806
+
1807
+ @staticmethod
1808
+ def _bound_variable(name, *args, **kwargs):
1809
+ """
1810
+ If the variable comes from an FX node, we forward the bound we have already computed
1811
+ Else, if the variable is created while codegen'ing another op, we try to compute its bounds
1812
+ """
1813
+ from ..select_algorithm import TritonTemplateKernel
1814
+
1815
+ if isinstance(V.kernel, TritonTemplateKernel):
1816
+ return ValueRanges.unknown()
1817
+
1818
+ fx_node = V.interpreter.current_node
1819
+ if fx_node.target == name and self.node_to_bounds is not None:
1820
+ assert isinstance(self.node_to_bounds, dict)
1821
+ return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
1822
+ elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
1823
+ # These create lots of inner strings. We would need to compute the bounds at the ops
1824
+ # We will also likely not get much from computing VRs on these nodes
1825
+ if any(
1826
+ s in fx_node.target
1827
+ for s in ("set_indirect", "reduction", "scan")
1828
+ ):
1829
+ return ValueRanges.unknown()
1830
+
1831
+ # We assume that the inputs come from `ops.` and are not strings. If you want to generate
1832
+ # intermediary strings, wrap them in CSE variables with properly initialised bounds.
1833
+
1834
+ # If there is no FX bound but we know how to compute one we do so
1835
+ assert not kwargs
1836
+
1837
+ def arg_to_bound(x):
1838
+ if isinstance(x, CSEVariable):
1839
+ return x.bounds
1840
+ elif isinstance(x, sympy.Expr):
1841
+ return bound_sympy(x)
1842
+ else:
1843
+ return x
1844
+
1845
+ arg_bounds = list(map(arg_to_bound, args))
1846
+ return getattr(CSEProxy.vr_analysis, name)(*arg_bounds)
1847
+ else:
1848
+ return ValueRanges.unknown()
1849
+
1850
+ @staticmethod
1851
+ def indirect_indexing(
1852
+ var: CSEVariable,
1853
+ size: Union[sympy.Expr, int],
1854
+ check: bool = True,
1855
+ wrap_neg=True,
1856
+ ):
1857
+ if isinstance(size, int):
1858
+ size = sympy.Integer(size)
1859
+ assert isinstance(size, sympy.Expr), size
1860
+ # Skip CSE since this doesn't return an expression
1861
+
1862
+ if var.bounds.lower < 0: # type: ignore[operator]
1863
+ if wrap_neg:
1864
+ stm = ops.add(var, ops.index_expr(size, torch.long))
1865
+ # Mixed negative and non-negative
1866
+ if var.bounds.upper >= 0: # type: ignore[operator]
1867
+ lt = ops.lt(var, 0)
1868
+ stm = ops.where(lt, stm, var)
1869
+ else:
1870
+ stm = var
1871
+
1872
+ # Propagate bounds as we know how to compute them properly
1873
+ new_bounds = ValueRanges.unknown()
1874
+ if var.bounds != ValueRanges.unknown() and isinstance(
1875
+ size, sympy.Number
1876
+ ):
1877
+ # Take the negative part of the bound and add size to it
1878
+ # Then take union of that and the positive part
1879
+ # This is a tighter bound than that of a generic ops.where, as we have info on the cond
1880
+ neg_bounds = var.bounds & ValueRanges(-int_oo, -1)
1881
+ new_bounds = ValueRanges(
1882
+ neg_bounds.lower + size, neg_bounds.upper + size
1883
+ )
1884
+ # We don't have a good way of representing the empty range
1885
+ if var.bounds.upper >= 0: # type: ignore[operator]
1886
+ pos = var.bounds & ValueRanges(0, int_oo)
1887
+ new_bounds = new_bounds | pos
1888
+
1889
+ var = self.cse.generate(self.compute, stm, bounds=new_bounds)
1890
+
1891
+ sympy_var = parent_handler.indirect_indexing(var, size, check)
1892
+ if generate_assert(check):
1893
+ assert_lower = not (var.bounds.lower >= 0)
1894
+ # value ranges cannot prove x < s when x and s are symbols
1895
+ assert_upper = not isinstance(size, sympy.Number) or not (
1896
+ var.bounds.upper < size
1897
+ )
1898
+ self.check_bounds(sympy_var, size, assert_lower, assert_upper)
1899
+ return sympy_var
1900
+
1901
+ @staticmethod
1902
+ def check_bounds(
1903
+ expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
1904
+ ):
1905
+ return self.check_bounds(expr, size, lower, upper)
1906
+
1907
+ @staticmethod
1908
+ def load(name: str, index: sympy.Expr) -> CSEVariable:
1909
+ if name in self.cse.invalidated_stores:
1910
+ # A load from an invalidated store requires us to
1911
+ # keep the actual buffer around
1912
+ V.kernel.must_keep_buffers.add(name)
1913
+ if free_symbol_is_type(index, SymT.TMP):
1914
+ return self.indirect_load(name, index)
1915
+ store_cache = self.cse.store_cache
1916
+ if name in store_cache:
1917
+ return store_cache[name]
1918
+ out = self.load(name, index)
1919
+ # count load that is not in the store_cache, and also not in the
1920
+ # cse cache.
1921
+ if out.use_count == 1:
1922
+ self.num_load += 1
1923
+ return out
1924
+
1925
+ @staticmethod
1926
+ def _update_store_cache(name: str, value: CSEVariable):
1927
+ self.cse.store_cache[name] = value
1928
+ if self.current_node and name in V.graph.name_to_buffer:
1929
+ buf = self.current_node.get_output(name)
1930
+ for other_name in buf.get_mutations():
1931
+ self.cse.store_cache[other_name] = value
1932
+
1933
+ @staticmethod
1934
+ def store(
1935
+ name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
1936
+ ) -> None:
1937
+ self.store_buffer_names.add(name)
1938
+ if mode is None:
1939
+ CSEProxy._update_store_cache(name, value)
1940
+ if name not in V.graph.removed_buffers:
1941
+ return self.store(name, index, value, mode=mode)
1942
+ else:
1943
+ return None # type: ignore[return-value]
1944
+
1945
+ @staticmethod
1946
+ def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
1947
+ self.store_buffer_names.add(name)
1948
+ CSEProxy._update_store_cache(name, value)
1949
+
1950
+ if name not in V.graph.removed_buffers:
1951
+ return self.store_reduction(name, index, value)
1952
+
1953
+ @staticmethod
1954
+ def reduction(
1955
+ dtype: torch.dtype,
1956
+ src_dtype: torch.dtype,
1957
+ reduction_type: ReductionType,
1958
+ value: Union[CSEVariable, Tuple[CSEVariable, ...]],
1959
+ ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
1960
+ self.num_reduction += 1
1961
+ return self.reduction(dtype, src_dtype, reduction_type, value)
1962
+
1963
+ @staticmethod
1964
+ def scan(
1965
+ dtypes: Tuple[torch.dtype, ...],
1966
+ combine_fn: Callable[
1967
+ [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]],
1968
+ Tuple[CSEVariable, ...],
1969
+ ],
1970
+ values: Tuple[CSEVariable, ...],
1971
+ ) -> Tuple[CSEVariable, ...]:
1972
+ return self.scan(dtypes, combine_fn, values)
1973
+
1974
+ @staticmethod
1975
+ def sort(
1976
+ dtypes: Tuple[torch.dtype, ...],
1977
+ values: Tuple[CSEVariable, ...],
1978
+ stable: bool,
1979
+ descending: bool,
1980
+ ) -> Tuple[CSEVariable, ...]:
1981
+ return self.sort(dtypes, values, stable, descending)
1982
+
1983
+ @staticmethod
1984
+ def bucketize(
1985
+ values: CSEVariable,
1986
+ offsets_name: str,
1987
+ offsets_size: sympy.Expr,
1988
+ indexing_dtype: torch.dtype,
1989
+ right: bool,
1990
+ ) -> CSEVariable:
1991
+ """
1992
+ [Note: Inductor bucketize op]
1993
+
1994
+ Given values (tensor) and offsets_name (reference to the name of a 1D
1995
+ tensor), calculate the bucket that each value belongs to.
1996
+
1997
+ e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
1998
+ return = [ 0, 1, 1, 1, 1, 3, 3, 4].
1999
+
2000
+ When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
2001
+ When right == True, bucket i refers to range [offsets[i], offsets[i+1]).
2002
+
2003
+ Offsets must be non-decreasing or the result is undefined.
2004
+ """
2005
+ return self.bucketize(
2006
+ values, offsets_name, offsets_size, indexing_dtype, right
2007
+ )
2008
+
2009
+ # Use mypy to check protocol implemented correctly
2010
+ def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
2011
+ return h
2012
+
2013
+ super().__enter__()
2014
+ assert self.overrides
2015
+ parent_handler = self.overrides(V.get_ops_handler())
2016
+ self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
2017
+ self.exit_stack.enter_context(V.set_kernel_handler(self))
2018
+ return self
2019
+
2020
+ def __exit__(self, exc_type, exc_val, exc_tb):
2021
+ """
2022
+ Note that V.graph.scheduler can be None when codegening triton template
2023
+ kernels.
2024
+ """
2025
+ if V.graph.scheduler:
2026
+ V.graph.scheduler.remove_kernel_local_buffers()
2027
+ super().__exit__(exc_type, exc_val, exc_tb)
2028
+
2029
+ def rename_indexing(self, index) -> sympy.Expr:
2030
+ # adds the necessary kernel args for index expressions
2031
+ # and renames variables in index expressions to kernel arg names
2032
+ if isinstance(index, (list, tuple)):
2033
+ return [self.rename_indexing(x) for x in index] # type: ignore[return-value]
2034
+ index = V.graph.sizevars.simplify(index)
2035
+ sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
2036
+ replacements = {
2037
+ x: self.args.size(x)
2038
+ for x in sorted_symbols
2039
+ if symbol_is_type(
2040
+ x,
2041
+ (
2042
+ SymT.UNBACKED_INT,
2043
+ SymT.SIZE,
2044
+ SymT.PRECOMPUTED_SIZE,
2045
+ ),
2046
+ )
2047
+ }
2048
+ return sympy_subs(index, replacements)
2049
+
2050
+ def create_cse_var(self, *args, **kwargs):
2051
+ return CSEVariable(*args, **kwargs)
2052
+
2053
+
2054
+ @dataclasses.dataclass
2055
+ class OptimizationContext:
2056
+ key: ClassVar[str] = "opt_ctx"
2057
+
2058
+ dtype: Optional[torch.dtype] = None
2059
+ ops_name: str = ""
2060
+
2061
+
2062
+ @functools.lru_cache(None)
2063
+ def jinja2_env():
2064
+ try:
2065
+ import jinja2
2066
+
2067
+ return jinja2.Environment(
2068
+ undefined=jinja2.StrictUndefined,
2069
+ )
2070
+ except ImportError:
2071
+ return None
2072
+
2073
+
2074
+ class KernelTemplate:
2075
+ """
2076
+ Base class for defining kernel templates.
2077
+
2078
+ Children classes: TritonTemplate, CUDATemplate
2079
+ """
2080
+
2081
+ @staticmethod
2082
+ def indent_except_first(source: str, num_indents: int, indents_spacing=4):
2083
+ lines = source.splitlines(True)
2084
+ if len(lines) > 1:
2085
+ lines[1:] = [
2086
+ (" " * indents_spacing * num_indents) + line for line in lines[1:]
2087
+ ]
2088
+ return "".join(lines)
2089
+
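# Sketch of the filter's behaviour:
#   KernelTemplate.indent_except_first("a = 1\nb = 2\nc = 3", 1)
# returns "a = 1\n    b = 2\n    c = 3" -- every line except the first gains
# num_indents * 4 spaces, keeping multi-line substitutions aligned in the templates.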
2090
+ @staticmethod
2091
+ def _template_from_string(source):
2092
+ env = jinja2_env()
2093
+ if env is not None:
2094
+ env.filters["indent_except_first"] = KernelTemplate.indent_except_first
2095
+ from jinja2 import TemplateSyntaxError
2096
+
2097
+ class DetailedTemplateSyntaxError(TemplateSyntaxError):
2098
+ def __init__(self, original_error):
2099
+ super().__init__(
2100
+ original_error.message,
2101
+ original_error.lineno,
2102
+ original_error.name,
2103
+ original_error.filename,
2104
+ )
2105
+ self.original_error = original_error
2106
+
2107
+ def __str__(self):
2108
+ error_info = f"Error in template at line {self.lineno}\n"
2109
+ error_info += f"Error message: {self.message}\n"
2110
+ if hasattr(self.original_error, "source"):
2111
+ lines = self.original_error.source.split("\n")
2112
+ error_info += "Context:\n"
2113
+ start = max(0, self.lineno - 2)
2114
+ end = min(len(lines), self.lineno + 2)
2115
+ for i in range(start, end):
2116
+ if i == self.lineno - 1:
2117
+ error_info += f"{i+1}: --> {lines[i]}\n"
2118
+ if hasattr(self.original_error, "column"):
2119
+ error_info += (
2120
+ " "
2121
+ + " " * (self.original_error.column - 1)
2122
+ + "^\n"
2123
+ )
2124
+ else:
2125
+ error_info += f"{i+1}: {lines[i]}\n"
2126
+ return error_info
2127
+
2128
+ try:
2129
+ return env.from_string(source)
2130
+ except TemplateSyntaxError as e:
2131
+ raise DetailedTemplateSyntaxError(e) from e
2132
+
2133
+ return None
2134
+
2135
+ @staticmethod
2136
+ def _fake_get_dtype(fake_out):
2137
+ _get_dtype_real = V.graph.get_dtype
2138
+
2139
+ def get_dtype(name):
2140
+ if name == fake_out.get_name():
2141
+ return fake_out.get_dtype()
2142
+ return _get_dtype_real(name)
2143
+
2144
+ return get_dtype
2145
+
2146
+ def __init__(self, name: str):
2147
+ self.name = name
2148
+
2149
+ def maybe_append_choice(self, choices, **kwargs):
2150
+ """
2151
+ Maybe generates a new ChoiceCaller and appends it into existing choices.
2152
+
2153
+ choices: A list of ChoiceCallers.
2154
+ kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
2155
+ """
2156
+
2157
+ try:
2158
+ choices.append(self.generate(**kwargs))
2159
+ except NotImplementedError as e:
2160
+ pass
2161
+
2162
+ def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller":
2163
+ """
2164
+ Generates a ChoiceCaller instance from the given arguments.
2165
+ """
2166
+
2167
+ raise NotImplementedError
.venv/Lib/site-packages/torch/_inductor/codegen/cpp.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/Lib/site-packages/torch/_inductor/codegen/cpp_gemm_template.py ADDED
@@ -0,0 +1,1043 @@
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ import logging
4
+ import math
5
+ from functools import lru_cache
6
+ from typing import Any, Callable, cast, List, Optional, Set, Union
7
+ from unittest.mock import patch
8
+
9
+ import torch
10
+ import torch.utils
11
+
12
+ from ..._dynamo.utils import counters
13
+ from .. import config, ir, lowering as L
14
+ from ..kernel.mm_common import mm_args
15
+ from ..select_algorithm import DataProcessorTemplateWrapper
16
+ from ..utils import cache_on_self, has_free_symbols, parallel_num_threads
17
+ from ..virtualized import ops, V
18
+ from .cpp import get_export_declaration
19
+ from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType
20
+ from .cpp_template import CppTemplate
21
+ from .cpp_template_kernel import CppTemplateKernel
22
+ from .cpp_utils import (
23
+ create_epilogue_with_attr,
24
+ DTYPE_TO_CPP,
25
+ GemmBlocking,
26
+ get_gemm_template_output_and_compute_dtype,
27
+ )
28
+
29
+
30
+ log = logging.getLogger(__name__)
31
+
32
+ GEMM_TEMPLATE = r"""
33
+ {{template.header().getvalue()}}
34
+
35
+ {{micro_gemm.codegen_define(kernel)}}
36
+
37
+ {%- if x_scale is not none %}
38
+ {%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %}
39
+ {%- else %}
40
+ {%- set kernel_args = {"X": X, "W": W, "inp": inp} %}
41
+ {%- endif %}
42
+
43
+ extern "C" {{export_declaration}}
44
+ {{kernel.def_kernel(inputs=kernel_args, outputs={"Y": Y}, aliases=aliases)}}
45
+ {
46
+ {{kernel.maybe_codegen_profile()}}
47
+ constexpr int64_t num_threads = {{num_threads}};
48
+ constexpr int64_t N = {{N}};
49
+ constexpr int64_t K = {{K}};
50
+ constexpr int64_t Mr = {{micro_gemm.register_blocking.block_m}};
51
+ constexpr int64_t Nr = {{micro_gemm.register_blocking.block_n}};
52
+ constexpr int64_t Kr = {{micro_gemm.register_blocking.block_k}};
53
+ constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
54
+ constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
55
+
56
+ {%- if is_dynamic_M %}
57
+ const int64_t M = {{kernel.size(GemmOut, 0)}};
58
+ const int64_t Mr_blocks = (M + Mr - 1) / Mr;
59
+ {%- if num_threads > 1 %}
60
+ int64_t Mt_blocks, Nt_blocks, Kt_blocks;
61
+ mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks);
62
+ {%- else %}
63
+ const auto Mt_blocks = Mr_blocks;
64
+ const auto Nt_blocks = Nr_blocks;
65
+ const auto Kt_blocks = Kr_blocks;
66
+ {%- endif %}
67
+ int64_t Mc_blocks, Nc_blocks, Kc_blocks;
68
+ uint32_t L1_cache_size = {{L1_cache_size}};
69
+ uint32_t L2_cache_size = {{L2_cache_size}};
70
+ mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>(
71
+ num_threads,
72
+ M,
73
+ N,
74
+ K,
75
+ Mr,
76
+ Nr,
77
+ Kr,
78
+ Mt_blocks,
79
+ Nt_blocks,
80
+ Kt_blocks,
81
+ Mc_blocks,
82
+ Nc_blocks,
83
+ Kc_blocks,
84
+ L1_cache_size,
85
+ L2_cache_size
86
+ );
87
+ const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
88
+ const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
89
+ const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
90
+ {%- else %}
91
+ constexpr int64_t M = {{kernel.size(GemmOut, 0)}};
92
+ constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
93
+ constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}};
94
+ constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}};
95
+ constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}};
96
+ constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}};
97
+ constexpr int64_t Nc_blocks = {{template.cache_blocking().block_n}};
98
+ constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}};
99
+ constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
100
+ constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
101
+ constexpr int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
102
+ {%- endif %}
103
+
104
+ // make sure all partitions are assigned
105
+ {{kernel.assert_function}}(
106
+ Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= Mr_blocks * Nr_blocks * Kr_blocks,
107
+ "Not all partitions are assigned."
108
+ );
109
+
110
+ {%- if maybe_k_slicing %}
111
+ std::unique_ptr<std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[]> local_buf_ptrs;
112
+ if (num_k_slices > 1) {
113
+ local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_k_slices]);
114
+ }
115
+ {%- endif %}
116
+
117
+ {%- if num_threads > 1 %}
118
+ #pragma omp parallel num_threads({{num_threads}})
119
+ {
120
+ const int tid = omp_get_thread_num();
121
+ int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end;
122
+ mm_get_thread_blocks(
123
+ tid, Mr_blocks, Nr_blocks, Kr_blocks, Mt_blocks, Nt_blocks, Kt_blocks,
124
+ m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end);
125
+ {%- if maybe_k_slicing %}
126
+ const int64_t k_group_id = tid / num_k_slices;
127
+ const int64_t k_slice_id = tid % num_k_slices;
128
+ {%- endif %}
129
+ {%- else %}
130
+ {
131
+ const int tid = 0;
132
+ const int64_t m_block_start = 0;
133
+ const int64_t m_block_end = Mr_blocks;
134
+ const int64_t n_block_start = 0;
135
+ const int64_t n_block_end = Nr_blocks;
136
+ const int64_t k_block_start = 0;
137
+ const int64_t k_block_end = Kr_blocks;
138
+ {%- endif %}
139
+ {{ micro_gemm.codegen_init(kernel) }}
140
+ {%- if use_local_acc %}
141
+ {%- set acc_buf_name = "local_acc_buf" %}
142
+ {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
143
+ {%- endif %}
144
+ for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
145
+ const int64_t m_start = mc * Mr;
146
+ const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
147
+ const int64_t m_size = m_end - m_start;
148
+ for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
149
+ const int64_t n_start = nc * Nr;
150
+ const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
151
+ const int64_t n_size = n_end - n_start;
152
+ // NB: assume we pad N, nc_block_end won't exceed padded N here.
153
+ const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
154
+ {%- if use_local_acc %}
155
+ {%- set acc = kernel.local_buffers[acc_buf_name] %}
156
+ {{ kernel.reinit_buffer_if_null(acc_buf_name) }}
157
+ {%- else %}
158
+ {%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_end")]) %}
159
+ {%- endif %}
160
+ for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
161
+ int64_t k_start = kc * Kr;
162
+ int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
163
+ {%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %}
164
+ for (int64_t nci = nc; nci < nc_block_end; nci++) {
165
+ {%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %}
166
+ {%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %}
167
+ {%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
168
+ if (kc == k_block_start) {
169
+ {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=False)|indent(28, false) }}
170
+ } else {
171
+ {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=True)|indent(28, false) }}
172
+ }
173
+ }
174
+ }
175
+ {%- if maybe_k_slicing %}
176
+ if (num_k_slices > 1) {
177
+ const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
178
+ local_buf_ptrs[mxn_cache_block_id * num_k_slices + k_slice_id].reset({{ kernel.release_buffer(acc_buf_name) }});
179
+ } else
180
+ {%- endif %}
181
+ {
182
+ {%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %}
183
+ {%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %}
184
+ {{ kernel.store_output(
185
+ tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
186
+ )|indent(20, false)
187
+ }}
188
+ }
189
+ }
190
+ }
191
+ {%- if maybe_k_slicing %}
192
+ if (num_k_slices > 1) {
193
+ #pragma omp barrier
194
+ for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) {
195
+ // We slice M-dim and each thread in the k-slicing group works on a slice
196
+ const int64_t m_start_unsliced = mc * Mr;
197
+ const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
198
+ const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced;
199
+ const int64_t m_slice_size = (m_size_unsliced + num_k_slices - 1) / num_k_slices;
200
+ const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced);
201
+ const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced);
202
+ const int64_t m_size = m_end - m_start;
203
+ const int64_t m_offset = m_start - m_start_unsliced;
204
+ for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
205
+ const int64_t n_start = nc * Nr;
206
+ const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
207
+ const int64_t n_size = n_end - n_start;
208
+ const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc;
209
+ auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get();
210
+ for (int64_t other_slice = 1; other_slice < num_k_slices; other_slice++) {
211
+ auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_k_slices + other_slice].get();
212
+ for (int64_t m = m_offset; m < m_offset + m_size; m++) {
213
+ #pragma omp simd
214
+ for (int64_t n = 0; n < n_size; n++) {
215
+ {{acc_buf_name}}[m*Nr + n] += other_acc[m*Nr + n];
216
+ }
217
+ }
218
+ }
219
+ {%- set tile_acc_m_slice = kernel.slice_nd(tile_acc, [("m_offset", "m_offset + m_end - m_start"), ()]) %}
220
+ {{ kernel.store_output(
221
+ tile_Y, tile_acc_m_slice, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
222
+ )|indent(20, false)
223
+ }}
224
+ }
225
+ }
226
+ }
227
+ {%- endif %}
228
+ {{ micro_gemm.codegen_finalize(kernel) }}
229
+ }
230
+ }
231
+ """
232
+
233
+
234
+ def get_padded_n(n, block_n):
235
+ return (n + block_n - 1) // block_n * block_n
236
+
237
+
238
+ class CppPackedGemmTemplate(CppTemplate):
239
+ def __init__(
240
+ self,
241
+ input_nodes,
242
+ layout: ir.Layout,
243
+ num_threads: int,
244
+ register_blocking: GemmBlocking,
245
+ beta=1,
246
+ alpha=1,
247
+ has_bias=False,
248
+ epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
249
+ ) -> None:
250
+ assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8]
251
+ super().__init__(
252
+ "packed_gemm",
253
+ input_nodes,
254
+ layout,
255
+ num_threads,
256
+ epilogue_creator=epilogue_creator,
257
+ )
258
+ self.beta = beta
259
+ self.alpha = alpha
260
+ self.has_bias = has_bias
261
+ self.register_blocking = register_blocking
262
+ m, n = layout.size
263
+ _, k = input_nodes[0].get_size()
264
+ self.m, self.n, self.k = m, n, k
265
+ self.padded_n = get_padded_n(n, self.register_blocking.block_n)
266
+ self.is_dynamic_M = has_free_symbols((m,))
267
+
268
+ @cache_on_self
269
+ def thread_blocking(self) -> GemmBlocking:
270
+ """
271
+ NOTE [Thread blocking in Cpp GEMM]
272
+ We use simple heuristics to decide the thread blocking:
273
+ 1. Make sure all threads are occupied as much as possible.
274
+ 2. For (m, n) blocks, favor more square-sized thread blocks for better data reuse.
275
+ 3. If (m, n) blocks cannot occupy all the threads, we consider k-slicing.
276
+ TODO(jgong5): allow tuning various blocking options
277
+ """
278
+
279
+ @lru_cache(maxsize=100)
280
+ def get_factors(number):
281
+ factors = []
282
+ for i in range(int(number**0.5), 0, -1):
283
+ if number % i == 0:
284
+ factors.append(number // i)
285
+ factors.append(i)
286
+ return factors
287
+
288
+ def get_blocking(m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks):
289
+ thread_block_k = math.ceil(k_blocks / k_factor)
290
+ thread_block_n = math.ceil(n_blocks / n_factor)
291
+ thread_block_m = math.ceil(m_blocks / m_factor)
292
+ return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)
293
+
294
+ assert (
295
+ not self.is_dynamic_M
296
+ ), "Unable to determine thread blocking for dynamic M."
297
+ register_blocking = self.register_blocking
298
+ m_blocks = math.ceil(self.m / register_blocking.block_m)
299
+ n_blocks = math.ceil(self.n / register_blocking.block_n)
300
+ k_blocks = math.ceil(self.k / register_blocking.block_k)
301
+ factors = get_factors(self.num_threads)
302
+ assert len(factors) > 0
303
+
304
+ if config.cpp.gemm_thread_factors is not None:
305
+ factors = [int(i) for i in config.cpp.gemm_thread_factors.split(",")]
306
+ assert len(factors) == 3
307
+ assert math.prod(factors) == self.num_threads
308
+ return get_blocking(
309
+ factors[0], factors[1], factors[2], m_blocks, n_blocks, k_blocks
310
+ )
311
+
312
+ # we favor square-sized thread blocks for good data reuse
313
+ def get_better_blocking(blocking, best_blocking):
314
+ if best_blocking is None:
315
+ best_blocking = blocking
316
+ else:
317
+ block_m_size = blocking.block_m * register_blocking.block_m
318
+ block_n_size = blocking.block_n * register_blocking.block_n
319
+ best_block_m_size = best_blocking.block_m * register_blocking.block_m
320
+ best_block_n_size = best_blocking.block_n * register_blocking.block_n
321
+ if blocking.block_k > best_blocking.block_k:
322
+ best_blocking = blocking
323
+ elif (
324
+ blocking.block_k == best_blocking.block_k
325
+ and block_m_size + block_n_size
326
+ < best_block_m_size + best_block_n_size
327
+ ):
328
+ best_blocking = blocking
329
+ return best_blocking
330
+
331
+ best_blocking = None
332
+ # check if we can have a thread-blocking to occupy all threads without k-slicing
333
+ for n_factor in factors:
334
+ m_factor = self.num_threads // n_factor
335
+ if n_blocks >= n_factor and m_blocks >= m_factor:
336
+ blocking = get_blocking(
337
+ m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
338
+ )
339
+ best_blocking = get_better_blocking(blocking, best_blocking)
340
+
341
+ if best_blocking is None:
342
+ for k_factor in factors:
343
+ if k_blocks >= k_factor and (
344
+ config.cpp.gemm_max_k_slices == 0
345
+ or k_factor <= config.cpp.gemm_max_k_slices
346
+ ):
347
+ n_factors = get_factors(self.num_threads // k_factor)
348
+ for n_factor in n_factors:
349
+ m_factor = (self.num_threads // k_factor) // n_factor
350
+ if n_blocks >= n_factor and m_blocks >= m_factor:
351
+ blocking = get_blocking(
352
+ m_factor,
353
+ n_factor,
354
+ k_factor,
355
+ m_blocks,
356
+ n_blocks,
357
+ k_blocks,
358
+ )
359
+ best_blocking = get_better_blocking(blocking, best_blocking)
360
+
361
+ if best_blocking is None:
362
+ for n_factor in factors:
363
+ m_factor = self.num_threads // n_factor
364
+ if n_blocks >= n_factor or m_blocks >= m_factor:
365
+ blocking = get_blocking(
366
+ m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks
367
+ )
368
+ best_blocking = get_better_blocking(blocking, best_blocking)
369
+
370
+ assert best_blocking is not None
371
+ return best_blocking
372
+
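# --- Editor's illustrative sketch (not part of the vendored file) ---
# A simplified, self-contained rendition of the heuristic in thread_blocking() above:
# enumerate factor pairs of num_threads, keep only splits every thread can occupy, and
# prefer the most "square" (m, n) split. k-slicing and the register-block scaling used in
# the real comparison are omitted; all names below are hypothetical.
import math


def _sketch_thread_blocking(num_threads, m_blocks, n_blocks, k_blocks):
    def factors(x):
        out = []
        for i in range(int(x**0.5), 0, -1):
            if x % i == 0:
                out.extend([x // i, i])
        return out

    best = None
    for n_factor in factors(num_threads):
        m_factor = num_threads // n_factor
        if n_blocks >= n_factor and m_blocks >= m_factor:
            cand = (math.ceil(m_blocks / m_factor),  # register blocks per thread along m
                    math.ceil(n_blocks / n_factor),  # ... along n
                    k_blocks)                        # no k-slicing in this pass
            if best is None or cand[0] + cand[1] < best[0] + best[1]:
                best = cand
    return best


# 16 threads on a 32 x 32 x 8 block grid -> a 4 x 4 thread grid with (8, 8, 8) blocks each:
assert _sketch_thread_blocking(16, 32, 32, 8) == (8, 8, 8)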
373
+ @cache_on_self
374
+ def cache_blocking(self) -> GemmBlocking:
375
+ def get_cache_blocking(register_blocking, thread_blocking):
376
+ Mr = register_blocking.block_m
377
+ Nr = register_blocking.block_n
378
+ Kr = register_blocking.block_k
379
+
380
+ Mt_blocks = thread_blocking.block_m
381
+ Nt_blocks = thread_blocking.block_n
382
+ Kt_blocks = thread_blocking.block_k
383
+
384
+ if config.cpp.gemm_cache_blocking is not None:
385
+ blockings = [int(i) for i in config.cpp.gemm_cache_blocking.split(",")]
386
+ assert len(blockings) == 3
387
+ Mc_blocks, Nc_blocks, Kc_blocks = blockings
388
+ return (
389
+ min(Mc_blocks, Mt_blocks),
390
+ min(Nc_blocks, Nt_blocks),
391
+ min(Kc_blocks, Kt_blocks),
392
+ )
393
+
394
+ # The ratios below are empirically determined to decide
395
+ # the effective sizes of L1 and L2.
396
+ # TODO: tune the factor here
397
+ L1_limit_factor = 0.8
398
+ L2_limit_factor = 0.5
399
+
400
+ L1_cache_size = (
401
+ torch._C._cpu._L1d_cache_size()
402
+ ) # per core cache size in Bytes
403
+ assert (
404
+ L1_cache_size > 0
405
+ ), f"Expect L1_cache_size > 0 but got {L1_cache_size}"
406
+ L1 = L1_cache_size * L1_limit_factor
407
+
408
+ L2_cache_size = (
409
+ torch._C._cpu._L2_cache_size()
410
+ ) # per core cache size in Bytes
411
+ assert (
412
+ L2_cache_size > 0
413
+ ), f"Expect L2_cache_size > 0 but got {L2_cache_size}"
414
+ L2 = L2_cache_size * L2_limit_factor
415
+
416
+ def get_num_byte(dtype):
417
+ return torch.tensor([], dtype=dtype).element_size()
418
+
419
+ num_byte_A = get_num_byte(self.input_nodes[0].get_dtype())
420
+ num_byte_B = get_num_byte(self.input_nodes[1].get_dtype())
421
+
422
+ # NOTE [CPP GEMM Cache Blocking Algorithm]
423
+ # Our overall strategy is to
424
+ # 1) Make cache blocks of B reside in L1 and be reused by multiple rows of A, i.e. Mc.
425
+ # Here, B is Kc x Nr where Nr is a single register block. We use the L1 size to
426
+ # decide Kc. We want Mc to be large enough to reuse B well.
427
+ # 2) Make cache blocks of A reside in L2, which limits Mc. We want to reuse A
428
+ # along N, where we have two sub-strategies (see notes below) to decide Mc and Nc.
429
+
430
+ # Step 1: Decide Kc assuming B block is L1-reside.
431
+ size_cache_B = Kr * Kt_blocks * Nr * num_byte_B
432
+ Kc_blocks = Kt_blocks
433
+ if size_cache_B > L1:
434
+ Kc_blocks = math.floor(L1 / (Kr * Nr * num_byte_B))
435
+
436
+ # Step 2: Decide Mc assuming A block is L2-reside.
437
+ min_Mc_ratio = 2 # TODO(jgong5): something to tune?
438
+ min_Mc_blocks = math.ceil(min_Mc_ratio * Mr / Nr)
439
+ assert min_Mc_blocks >= 1
440
+ Kt_bytes = Kt_blocks * Kr * num_byte_A
441
+ if min_Mc_blocks * Mr * Kt_bytes < L2:
442
+ # Strategy 1: A (Mc x Kt) resides in L2 and reused by all Nt
443
+ # when Nc_blocks is kept 1. Mc should be large enough (>= min_Mc_blocks)
444
+ # to reuse B (Kc x Nr) in L1. This makes C (Mc x Nr) small enough to reside
445
+ # in L1.
446
+ Mc_blocks = min(Mt_blocks, math.floor(L2 / (Mr * Kt_bytes)))
447
+ Nc_blocks = 1
448
+ else:
449
+ # Strategy 2: Kt is too large to hold A (Mc x Kt) in L2, we reuse
450
+ # A (Mc x Kc) in L2 by B (Kc x Nc). C (Mc x Nc) resides in L2.
451
+ Mc_blocks = Mt_blocks
452
+ Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)
453
+ Nc_bytes = Nc_blocks * Nr * 4 # assume C or acc is float32/int32
454
+ Kc_bytes = Kc_blocks * Kr * num_byte_A
455
+ if Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2:
456
+ # The following is the solution for 4*Mc*Nc + Mc*Kc_bytes = L2,
457
+ # assuming Mc == Nc for good data reuse.
458
+ M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8
459
+ if M_max < Mc_blocks * Mr:
460
+ Mc_blocks = math.floor(M_max / Mr)
461
+ Nc_blocks = min(math.ceil(Mc_blocks * Mr / Nr), Nt_blocks)
462
+
463
+ return Mc_blocks, Nc_blocks, Kc_blocks
464
+
465
+ assert (
466
+ not self.is_dynamic_M
467
+ ), "Unable to determine cache blocking for dynamic M."
468
+ register_blocking = self.register_blocking
469
+ thread_blocking = self.thread_blocking()
470
+
471
+ return GemmBlocking(*get_cache_blocking(register_blocking, thread_blocking))
472
+
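# --- Editor's illustrative sketch (not part of the vendored file) ---
# Derivation of the M_max formula used in Strategy 2 of get_cache_blocking() above.
# With M = Mc * Mr rows of A (Kc_bytes per row) and a float32 accumulator of roughly
# M x M elements (assuming Mc * Mr == Nc * Nr for square-ish reuse), the L2 budget is
#     M * Kc_bytes + 4 * M * M <= L2
# i.e. 4*M^2 + Kc_bytes*M - L2 = 0, whose positive root is
#     M_max = (sqrt(Kc_bytes**2 + 16 * L2) - Kc_bytes) / 8
# A quick numeric check with hypothetical sizes:
import math

Kc_bytes = 64 * 32 * 4              # e.g. Kc_blocks * Kr = 2048 fp32 elements per A row
L2 = int(1.25 * 1024 * 1024 * 0.5)  # e.g. 1.25 MiB L2 scaled by L2_limit_factor = 0.5
M_max = (math.sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8
assert abs(4 * M_max * M_max + Kc_bytes * M_max - L2) < 1e-3  # the budget holds with equality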
473
+ def log_blockings(self):
474
+ log.debug(f"Register blocking: {self.register_blocking}") # noqa: G004
475
+ if self.is_dynamic_M:
476
+ # thread and cache blockings are determined at runtime for dynamic shapes
477
+ return
478
+ log.debug(f"Cache blocking: {self.cache_blocking()}") # noqa: G004
479
+ thread_blocking = self.thread_blocking()
480
+ log.debug(f"Thread blocking: {thread_blocking}") # noqa: G004
481
+
482
+ def get_occupancy():
483
+ m_blocks = math.ceil(self.m / self.register_blocking.block_m)
484
+ n_blocks = math.ceil(self.n / self.register_blocking.block_n)
485
+ k_blocks = math.ceil(self.k / self.register_blocking.block_k)
486
+ m = math.ceil(m_blocks / thread_blocking.block_m)
487
+ n = math.ceil(n_blocks / thread_blocking.block_n)
488
+ k = math.ceil(k_blocks / thread_blocking.block_k)
489
+ return (m, n, k)
490
+
491
+ log.debug(
492
+ f"Number of threads: {self.num_threads}, occupancy: {get_occupancy()}" # noqa: G004
493
+ )
494
+
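# [Editor's note] The occupancy tuple above counts thread tiles along (m, n, k): each entry is
# how many thread-blocking tiles that dimension is split into, so their product matches
# self.num_threads when the threads are fully occupied.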
495
+ def maybe_k_slicing(self):
496
+ if self.num_threads == 1:
497
+ return False
498
+ if self.is_dynamic_M:
499
+ # TODO(jgong5): perhaps use size hint to decide?
500
+ return True
501
+ register_blocking = self.register_blocking
502
+ k_blocks = math.ceil(self.k / register_blocking.block_k)
503
+ thread_blocking = self.thread_blocking()
504
+ return k_blocks > thread_blocking.block_k
505
+
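# [Editor's note] Reading of maybe_k_slicing(): k_blocks > thread_blocking().block_k means a
# single thread's Kt range does not cover all of K, i.e. the thread grid splits K across
# threads; each K slice then accumulates into its own local buffer, and the slices are summed
# as shown in the accumulation loop of GEMM_TEMPLATE above.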
506
+ @staticmethod
507
+ def add_choices(
508
+ choices,
509
+ layout,
510
+ input_nodes,
511
+ beta=1,
512
+ alpha=1,
513
+ has_bias=False,
514
+ trans_w=False,
515
+ input_indices=None,
516
+ epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
517
+ ):
518
+ if input_indices is None:
519
+ input_indices = list(range(len(input_nodes)))
520
+
521
+ def reorder_and_filter(inputs, layout_or_out):
522
+ if has_bias:
523
+ assert len(input_indices) >= 3
524
+ # Assume the input order is [inp, x, w] and we reorder it to [x, w, inp]
525
+ inp_idx = input_indices[0]
526
+ x_idx = input_indices[1]
527
+ w_idx = input_indices[2]
528
+ return [
529
+ inputs[x_idx],
530
+ inputs[w_idx],
531
+ inputs[inp_idx],
532
+ *[inputs[idx] for idx in input_indices[3:]],
533
+ ], layout_or_out
534
+ else:
535
+ assert len(input_indices) >= 2
536
+ return [inputs[idx] for idx in input_indices], layout_or_out
537
+
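# [Editor's note] For example, with has_bias=True and the default input_indices, the incoming
# order [inp, x, w] is reordered to [x, w, inp] so that positions 0 and 1 are always the GEMM
# operands and the bias (if present) comes third; any extra inputs keep their relative order.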
538
+ def maybe_to_dense(inputs, layout_or_out):
539
+ new_inputs = list(inputs)
540
+ if isinstance(inputs[1], torch.Tensor):
541
+ W = inputs[1]
542
+ new_inputs[1] = W.to_dense() if W.is_mkldnn else W
543
+ return new_inputs, layout_or_out
544
+
545
+ def normalize_shapes(inputs, layout_or_out):
546
+ if not trans_w:
547
+ return inputs, layout_or_out
548
+ new_inputs = list(inputs)
549
+ X = inputs[0]
550
+ W = inputs[1]
551
+ B = inputs[2] if has_bias else None
552
+ if isinstance(W, ir.IRNode):
553
+ if trans_w:
554
+ if not isinstance(W, ir.TensorBox):
555
+ W = ir.TensorBox(W)
556
+ W = L.permute(W, [1, 0])
557
+ else:
558
+ if trans_w:
559
+ assert isinstance(W, torch.Tensor)
560
+ W = W.transpose(0, 1)
561
+ if B is not None:
562
+ if isinstance(B, ir.IRNode):
563
+ if not isinstance(B, ir.TensorBox):
564
+ B = ir.TensorBox(B)
565
+ B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
566
+ else:
567
+ assert isinstance(B, torch.Tensor)
568
+ B = B.expand(X.shape[0], B.shape[-1])
569
+ new_inputs[1] = W
570
+ if B is not None:
571
+ new_inputs[2] = B
572
+ return new_inputs, layout_or_out
573
+
574
+ # TODO(jgong5): decide proper number of threads per problem size
575
+ num_threads = parallel_num_threads()
576
+ new_inputs, _ = normalize_shapes(
577
+ *maybe_to_dense(*reorder_and_filter(input_nodes, layout))
578
+ )
579
+ m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1])
580
+ output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
581
+ new_inputs[0].get_dtype()
582
+ )
583
+ micro_gemm = create_micro_gemm(
584
+ "micro_gemm",
585
+ m,
586
+ n,
587
+ k,
588
+ input_dtype=new_inputs[0].get_dtype(),
589
+ input2_dtype=new_inputs[1].get_dtype(),
590
+ output_dtype=output_dtype,
591
+ compute_dtype=compute_dtype,
592
+ alpha=alpha,
593
+ num_threads=num_threads,
594
+ )
595
+ assert micro_gemm is not None
596
+ _, block_n, _ = micro_gemm.register_blocking
597
+ padded_n = get_padded_n(n, block_n)
598
+
599
+ def pack_weight(inputs, layout_or_out):
600
+ W = inputs[1]
601
+ new_inputs = list(inputs)
602
+ blocked_w: Union[ir.IRNode, torch.Tensor] = W
603
+ if isinstance(W, ir.IRNode):
604
+ new_size = [padded_n // block_n, k, block_n]
605
+ blocked_w = ir.Buffer(
606
+ W.get_name(), # Borrow the registered buffer name
607
+ ir.FixedLayout(
608
+ W.get_device(),
609
+ W.get_dtype(),
610
+ new_size,
611
+ ir.FlexibleLayout.contiguous_strides(new_size),
612
+ 0,
613
+ ),
614
+ )
615
+ else:
616
+ blocked_w = (
617
+ torch.nn.functional.pad(W, (0, padded_n - n))
618
+ .reshape(k, padded_n // block_n, block_n)
619
+ .transpose(0, 1)
620
+ .contiguous()
621
+ )
622
+ if micro_gemm.get_b_layout() != LayoutType.NORMAL:
623
+ layout_str = (
624
+ "VNNI4"
625
+ if micro_gemm.get_b_layout() == LayoutType.VNNI4
626
+ else "VNNI2"
627
+ )
628
+ assert micro_gemm.get_b_layout() in [
629
+ LayoutType.VNNI2,
630
+ LayoutType.VNNI4,
631
+ ], f"We only support {layout_str} for now"
632
+ vnni_size = (
633
+ 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
634
+ )
635
+ assert (
636
+ k % vnni_size == 0
637
+ ), f"k should be divisible by vnni_size for {layout_str} layout"
638
+ blocked_w = (
639
+ blocked_w.view(
640
+ padded_n // block_n, k // vnni_size, vnni_size, block_n
641
+ )
642
+ .transpose(-1, -2)
643
+ .contiguous()
644
+ .view(padded_n // block_n, k, block_n)
645
+ )
646
+ # normalize the strides to be "contiguous_strides" for the given size;
647
+ # this avoids problems in L.view during template codegen
648
+ new_stride = [1]
649
+ for sz in reversed(blocked_w.shape[1:]):
650
+ new_stride.insert(0, new_stride[0] * sz)
651
+ blocked_w = blocked_w.as_strided(blocked_w.shape, new_stride)
652
+ new_inputs[1] = blocked_w
653
+
654
+ def _is_int8_gemm(inputs):
655
+ return (
656
+ isinstance(inputs[0], ir.IRNode)
657
+ and inputs[0].get_dtype() == torch.uint8
658
+ ) or (
659
+ isinstance(inputs[0], torch.Tensor)
660
+ and inputs[0].dtype == torch.uint8
661
+ )
662
+
663
+ if _is_int8_gemm(new_inputs):
664
+ BCompensate = None
665
+ if isinstance(W, ir.IRNode):
666
+ BCompensate = V.graph.add_tensor_constant(
667
+ V.graph.constants[W.get_name() + "_BMatrixCompens"],
668
+ W.get_name() + "_BMatrixCompens",
669
+ )
670
+ else:
671
+ BCompensate = torch.sum(W.to_dense().to(torch.float), dim=0) # type: ignore[assignment]
672
+ new_inputs.append(BCompensate)
673
+ return new_inputs, layout_or_out
674
+
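# --- Editor's illustrative sketch (not part of the vendored file) ---
# The plain-tensor branch of pack_weight() above, on a tiny weight. W is (k, n); it is
# zero-padded to padded_n columns and regrouped so each (k, block_n) panel is contiguous
# for the micro-kernel. Sizes are hypothetical and chosen for illustration.
import torch

k, n, block_n = 8, 10, 4
padded_n = (n + block_n - 1) // block_n * block_n            # 12
W = torch.arange(k * n, dtype=torch.float32).reshape(k, n)
blocked_w = (
    torch.nn.functional.pad(W, (0, padded_n - n))            # zero-pad columns n..padded_n-1
    .reshape(k, padded_n // block_n, block_n)
    .transpose(0, 1)
    .contiguous()                                            # (padded_n // block_n, k, block_n)
)
# element W[kk, nn] lands in n-block nn // block_n, row kk, column nn % block_n:
assert blocked_w[1, 3, 2] == W[3, 1 * block_n + 2]

# For VNNI2 micro-kernels (bf16/fp16), pairs of adjacent k rows are additionally interleaved
# per column, i.e. the memory order becomes (n_block, k // 2, block_n, 2):
vnni2_w = (
    blocked_w.view(padded_n // block_n, k // 2, 2, block_n)
    .transpose(-1, -2)
    .contiguous()
    .view(padded_n // block_n, k, block_n)
)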
675
+ def preprocessor(inputs, layout):
676
+ return pack_weight(
677
+ *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout)))
678
+ )
679
+
680
+ def postprocessor(output):
681
+ if isinstance(output, ir.TensorBox):
682
+ # prepack the weight as input to the template buffer
683
+ template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
684
+ assert isinstance(template_buffer, ir.CppTemplateBuffer)
685
+ new_input_nodes, _ = reorder_and_filter(input_nodes, layout)
686
+
687
+ W_node = new_input_nodes[1]
688
+ assert W_node.get_name() in V.graph.constants
689
+ W = V.graph.constants[W_node.get_name()]
690
+ new_input_nodes[1] = W
691
+ new_input_nodes, _ = pack_weight(
692
+ *normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
693
+ )
694
+
695
+ # By using the new packed weight for the GEMM template, we can prune the
696
+ # old weight if it has no other users. This saves memory but makes the FX graph
697
+ # non-retraceable. To support retracing, we can add a repack node to the
698
+ # FX graph. For example:
699
+ # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template
700
+ W_tensor_users = 0
701
+ for node in reversed(V.graph.graph.nodes):
702
+ # This case may happen when the wgt tensor is used by more than one get_attr node
703
+ # https://github.com/pytorch/pytorch/issues/134998
704
+ if node.op == "get_attr" and hasattr(
705
+ V.graph.module, node.name
706
+ ): # wgt might already be deleted
707
+ comp_tensor = getattr(V.graph.module, node.name)
708
+ if (
709
+ W.is_mkldnn == comp_tensor.is_mkldnn
710
+ and W.dtype == comp_tensor.dtype
711
+ and W.device == comp_tensor.device
712
+ and (
713
+ (
714
+ not W.is_mkldnn
715
+ and (
716
+ W.untyped_storage().data_ptr()
717
+ == comp_tensor.untyped_storage().data_ptr()
718
+ )
719
+ )
720
+ or (
721
+ W.is_mkldnn
722
+ and (
723
+ torch.ops.mkldnn.data_ptr(W)
724
+ == torch.ops.mkldnn.data_ptr(comp_tensor)
725
+ )
726
+ )
727
+ )
728
+ ):
729
+ W_tensor_users += 1
730
+
731
+ for node in reversed(V.graph.graph.nodes):
732
+ # The wgt tensor has been used by only one get_attr node
733
+ # and that get_attr node has only one user fx node
734
+ if (
735
+ node.name == W_node.get_name()
736
+ and len(node.users) == 1
737
+ and W_tensor_users == 1
738
+ ):
739
+ del V.graph.constants[node.name]
740
+ delattr(V.graph.module, node.name)
741
+ delattr(V.graph.graph.owning_module, node.name)
742
+
743
+ W_packed = new_input_nodes[1]
744
+ W_packed_constant = V.graph.add_tensor_constant(W_packed)
745
+ template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input(
746
+ W_packed_constant
747
+ )
748
+ return output
749
+
750
+ template = DataProcessorTemplateWrapper(
751
+ CppPackedGemmTemplate,
752
+ preprocessor,
753
+ postprocessor,
754
+ input_nodes=input_nodes,
755
+ layout=layout,
756
+ num_threads=num_threads,
757
+ register_blocking=micro_gemm.register_blocking,
758
+ beta=beta,
759
+ alpha=alpha,
760
+ has_bias=has_bias,
761
+ epilogue_creator=epilogue_creator,
762
+ )
763
+ template.maybe_append_choice(choices)
764
+ return template
765
+
766
+ def render( # type: ignore[override,return]
767
+ self,
768
+ kernel: CppTemplateKernel,
769
+ template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
770
+ flag_template_buffer_has_other_users: Optional[bool] = None,
771
+ epilogue_nodes: Optional[List[ir.IRNode]] = None,
772
+ **kwargs,
773
+ ) -> str:
774
+ assert len(self.input_nodes) >= 2
775
+
776
+ int8_gemm = self.input_nodes[0].get_dtype() == torch.uint8
777
+ x_scale = None
778
+ x_zp = None
779
+ w_scale = None
780
+ w_zp = None
781
+ if int8_gemm:
782
+ X, W = self.input_nodes[0], self.input_nodes[1]
783
+ bias_idx = 2 if self.has_bias else 1
784
+ inp = self.input_nodes[bias_idx] if self.has_bias else None
785
+ x_scale = self.input_nodes[bias_idx + 1]
786
+ x_zp = self.input_nodes[bias_idx + 2]
787
+ w_scale = self.input_nodes[bias_idx + 3]
788
+ w_zp = self.input_nodes[bias_idx + 4]
789
+ Y = self.output_node
790
+ else:
791
+ X, W = self.input_nodes[0], self.input_nodes[1]
792
+ Y = self.output_node
793
+ inp = self.input_nodes[2] if self.has_bias else None
794
+
795
+ template_buffer_has_other_users = None
796
+
797
+ if template_buffer_node is not None:
798
+ # Use the updated prepacked weight buffer
799
+ W = template_buffer_node.inputs[1]
800
+ Y = template_buffer_node
801
+
802
+ assert flag_template_buffer_has_other_users is not None
803
+ template_buffer_has_other_users = flag_template_buffer_has_other_users
804
+
805
+ template_buffer = Y
806
+ gemm_output_buffer = template_buffer
807
+
808
+ epilogues: List[ir.IRNode] = []
809
+ reindexers: List[Optional[Callable[[List[Any]], List[Any]]]] = []
810
+ epilogue_creators: List[Callable[[ir.Buffer], ir.Pointwise]] = []
811
+ fake_buffers: List[ir.Buffer] = []
812
+ Y_aliases: Set[str] = set()
813
+
814
+ use_local_acc = (
815
+ self.layout.dtype != torch.float
816
+ or template_buffer_has_other_users
817
+ or int8_gemm
818
+ or self.padded_n != self.n
819
+ or self.maybe_k_slicing()
820
+ )
821
+
822
+ # TODO(jgong5): for int8 gemm, bias-add is handled outside of gemm template,
823
+ # but it would be better to move it here to align with the fp path.
824
+ if inp is not None and self.beta != 0 and not int8_gemm:
825
+ # add an epilogue for bias add
826
+ def _bias_add_epilogue(buf):
827
+ return create_epilogue_with_attr(
828
+ buf, "bias_add", other=inp, beta=self.beta, dtype=self.layout.dtype
829
+ )
830
+
831
+ epilogue_creators.append(_bias_add_epilogue)
832
+
833
+ if self.epilogue_creator is not None:
834
+ epilogue_creators.append(self.epilogue_creator)
835
+
836
+ # When the GEMM output buffer is localized but it has users other than the epilogue nodes,
837
+ # we need to copy the value in the GEMM output local buffer to a global buffer.
838
+ def need_copy_from_local_to_global_buffer_epilogue(
839
+ use_local_acc, template_buffer_has_other_users, epilogue_creators
840
+ ):
841
+ # The GEMM output buffer is a global buffer, thus copy is not needed.
842
+ if not use_local_acc:
843
+ return False
844
+
845
+ # The possible values of template_buffer_has_other_users are (None, False, True).
846
+ # It is None when generating the gemm template during autotune and has a concrete value during scheduler codegen.
847
+ # An extra copy_from_local_to_global_buffer_epilogue is not needed in either of the two cases below:
848
+ # 1. template_buffer_has_other_users is None (i.e. when doing the codegen during autotune)
849
+ # 2. template_buffer_has_other_users is False, which means it is safe to keep the value of the
850
+ # GEMM output buffer in a local buffer only (no users outside of the epilogues will use its value).
851
+ if not template_buffer_has_other_users:
852
+ return False
853
+
854
+ # When bias is not None or self.epilogue_creator is not None,
855
+ # there will be epilogue_creators after the GEMM.
856
+ # The GEMM output buffer is localized while
857
+ # the output buffer of the epilogue_creators is a global buffer.
858
+ if epilogue_creators:
859
+ return False
860
+
861
+ return True
862
+
863
+ if need_copy_from_local_to_global_buffer_epilogue(
864
+ use_local_acc, template_buffer_has_other_users, epilogue_creators
865
+ ):
866
+
867
+ def copy_from_local_to_global_buffer_epilogue(input_buffer: ir.Buffer):
868
+ dtype = self.layout.dtype
869
+ input_loader = input_buffer.make_loader()
870
+
871
+ def copy_inner(index):
872
+ input = input_loader(index)
873
+ result = ops.to_dtype(input, dtype)
874
+ return result
875
+
876
+ return ir.Pointwise(
877
+ device=input_buffer.get_device(),
878
+ dtype=self.layout.dtype,
879
+ inner_fn=copy_inner,
880
+ ranges=input_buffer.get_size(),
881
+ )
882
+
883
+ epilogue_creators.append(copy_from_local_to_global_buffer_epilogue)
884
+
885
+ # NOTE [How CPP GEMM template epilogues are organized]
886
+ # gemm_output_buffer
887
+ # --> zero or more in-template epilogues (created by `epilogue_creators`) -->
888
+ # template_buffer
889
+ # --> zero or more out-of-template epilogues (`epilogue_nodes`) -->
890
+ # Y
891
+ if epilogue_creators:
892
+ gemm_output_name = "buf_GemmOut"
893
+ gemm_output_buffer = ir.Buffer(gemm_output_name, template_buffer.layout)
894
+ current_input_buffer = gemm_output_buffer
895
+ for i, creator in enumerate(epilogue_creators):
896
+ if i == len(epilogue_creators) - 1:
897
+ buffer_name = template_buffer.get_name()
898
+ else:
899
+ buffer_name = f"buf_GemmOut_epilogue_{i}"
900
+ epilogues.append(
901
+ ir.ComputedBuffer(
902
+ name=buffer_name,
903
+ layout=template_buffer.layout,
904
+ data=creator(current_input_buffer),
905
+ )
906
+ )
907
+ fake_buffers.append(current_input_buffer)
908
+ Y_aliases.add(current_input_buffer.get_name())
909
+ reindexers.append(None)
910
+ if i < len(epilogue_creators) - 1:
911
+ current_input_buffer = ir.Buffer(
912
+ buffer_name, template_buffer.layout
913
+ )
914
+
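# --- Editor's illustrative sketch (not part of the vendored file) ---
# The epilogue chain described in the NOTE above, reduced to plain callables: each creator
# consumes the previous buffer and produces the next, and only the last one targets the
# template buffer's name. The lambdas are hypothetical stand-ins for ir.Pointwise nodes.
creators = [
    lambda acc: acc + 0.5,   # e.g. the bias-add epilogue when beta != 0
    lambda acc: acc * 2.0,   # e.g. a caller-provided epilogue_creator (or the local->global copy)
]
buf = 3.0                    # stands in for the local GemmOut accumulator value
for i, creator in enumerate(creators):
    buf = creator(buf)       # buf_GemmOut_epilogue_{i} ... until the final template buffer
assert buf == (3.0 + 0.5) * 2.0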
915
+ Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
916
+
917
+ if epilogue_nodes:
918
+ epilogues.extend(epilogue_nodes)
919
+ assert Y.get_numel() == epilogues[-1].get_numel()
920
+ Y = cast(ir.Buffer, epilogues[-1])
921
+
922
+ if not template_buffer_has_other_users:
923
+ Y_aliases.add(template_buffer.get_name())
924
+
925
+ if (
926
+ Y.get_size() == template_buffer.get_size()
927
+ and Y.get_stride() == template_buffer.get_stride()
928
+ ):
929
+ reindexers.extend([None] * len(epilogue_nodes))
930
+ Y_2d = Y
931
+ else:
932
+
933
+ def get_reindexer(epilogue_node):
934
+ # From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example:
935
+ # template_buffer:
936
+ # size (324, 512), stride (512, 1)
937
+ # epilogue_node_ordered (ordered by stride decreasingly, in dense format):
938
+ # size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
939
+ stride_order = list(
940
+ ir.get_stride_order(
941
+ V.graph.sizevars.size_hints(epilogue_node.get_stride())
942
+ )
943
+ )
944
+ fill_order = ir.stride_order2fill_order(stride_order)
945
+ reversed_fill_order = list(reversed(fill_order))
946
+ size_with_stride_ordered_decreasingly = [
947
+ epilogue_node.get_size()[i] for i in reversed_fill_order
948
+ ]
949
+ reshape_reindex = ir.View.dynamic_reshape_indexer(
950
+ size_with_stride_ordered_decreasingly,
951
+ template_buffer.get_size(),
952
+ )
953
+
954
+ # From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example:
955
+ # epilogue_node_ordered (ordered by stride decreasingly, in dense format):
956
+ # size (1, 18, 18, 512), stride (165888, 9216, 512, 1)
957
+ # epilogue_node:
958
+ # size (1, 18, 18, 512), stride (165888, 1, 9216, 512)
959
+ from_stride_ordered_decreasingly_to_epilogue_node_order = [
960
+ (len(stride_order) - 1) - stride_order[i]
961
+ for i in range(len(stride_order))
962
+ ]
963
+ stride_reindex = ir.same_reorder(
964
+ from_stride_ordered_decreasingly_to_epilogue_node_order
965
+ )
966
+
967
+ reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex)
968
+ return reindexer
969
+
970
+ reindexers.extend([get_reindexer(epilogue_node) for epilogue_node in epilogue_nodes]) # type: ignore[list-item]
971
+ if isinstance(Y, ir.BaseView):
972
+ storage = ir.StorageBox(Y.unwrap_view())
973
+ else:
974
+ assert isinstance(Y, ir.Buffer)
975
+ storage = ir.StorageBox(Y)
976
+ Y_2d = ir.ReinterpretView(storage, template_buffer.get_layout())
977
+
978
+ output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
979
+ X.get_dtype()
980
+ )
981
+ micro_gemm = create_micro_gemm(
982
+ f"{kernel.kernel_name}_micro_gemm",
983
+ self.m,
984
+ self.n,
985
+ self.k,
986
+ input_dtype=X.get_dtype(),
987
+ input2_dtype=W.get_dtype(),
988
+ output_dtype=output_dtype,
989
+ compute_dtype=compute_dtype,
990
+ alpha=self.alpha,
991
+ num_threads=self.num_threads,
992
+ )
993
+ assert micro_gemm is not None
994
+ assert self.register_blocking == micro_gemm.register_blocking
995
+ self.log_blockings()
996
+ if isinstance(micro_gemm, CppMicroGemmAMX):
997
+ counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1
998
+
999
+ L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes
1000
+ assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"
1001
+
1002
+ L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes
1003
+ assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"
1004
+
1005
+ options = dict(
1006
+ X=X,
1007
+ W=W,
1008
+ inp=inp,
1009
+ Y=Y,
1010
+ N=self.n,
1011
+ K=self.k,
1012
+ PADDED_N=self.padded_n,
1013
+ GemmOut=gemm_output_buffer,
1014
+ aliases={alias: Y.get_name() for alias in Y_aliases},
1015
+ beta=self.beta,
1016
+ alpha=self.alpha,
1017
+ num_threads=self.num_threads,
1018
+ micro_gemm=micro_gemm,
1019
+ is_dynamic_M=self.is_dynamic_M,
1020
+ template=self,
1021
+ kernel=kernel,
1022
+ export_declaration=get_export_declaration(),
1023
+ epilogue_nodes=epilogues,
1024
+ reindexers=reindexers,
1025
+ Y_2d=Y_2d,
1026
+ use_local_acc=use_local_acc,
1027
+ maybe_k_slicing=self.maybe_k_slicing(),
1028
+ x_scale=x_scale,
1029
+ x_zp=x_zp,
1030
+ w_scale=w_scale,
1031
+ w_zp=w_zp,
1032
+ acc_buf_dtype=torch.int32 if int8_gemm else torch.float,
1033
+ DTYPE_TO_CPP=DTYPE_TO_CPP,
1034
+ L1_cache_size=L1_cache_size,
1035
+ L2_cache_size=L2_cache_size,
1036
+ config=config,
1037
+ )
1038
+ with contextlib.ExitStack() as stack:
1039
+ for buf in fake_buffers:
1040
+ stack.enter_context(
1041
+ patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
1042
+ )
1043
+ return self._template_from_string(GEMM_TEMPLATE).render(**options)
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py ADDED
File without changes
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py ADDED
@@ -0,0 +1,173 @@
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
7
+
8
+ import torch
9
+ import torch._inductor
10
+
11
+ aten = torch.ops.aten
12
+ prims = torch.ops.prims
13
+
14
+ from torch._inductor.pattern_matcher import (
15
+ Arg,
16
+ CallFunction,
17
+ CallFunctionVarArgs,
18
+ CallMethod,
19
+ CallMethodVarArgs,
20
+ CallModule,
21
+ CallModuleVarArgs,
22
+ ExclusiveKeywordArg,
23
+ Ignored,
24
+ KeywordArg,
25
+ ListOf,
26
+ MultiOutputPattern,
27
+ PatternExpr,
28
+ RepeatedExpr,
29
+ _TargetArgsExpr,
30
+ _TargetExpr,
31
+ _TargetExprVarArgs,
32
+ )
33
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
34
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
35
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
36
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
37
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
38
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
39
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
40
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
41
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
42
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
43
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
44
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
45
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
46
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
47
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
48
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
49
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
50
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
51
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
52
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
53
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
54
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
55
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
56
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
57
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
58
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
59
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
60
+ div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
61
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
62
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
63
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
64
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
65
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
66
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
67
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
68
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
69
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
70
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
71
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
72
+ _sfdp_pattern_1_training = MultiOutputPattern([view_default_5,
73
+ view_default_9,
74
+ permute_default_4,
75
+ view_default_11,
76
+ None
77
+ ])
78
+
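# [Editor's note] Reading of the pattern above (not part of the generated file): the chain
# expand/view -> bmm -> div by 'inv_scale' -> amax/sub/exp/sum/div is softmax(Q @ K^T / inv_scale),
# the following bmm with 'value' is the attention output, and the remaining bmm/permute nodes
# driven by 'tangents_1' form its backward; the MultiOutputPattern collects the output together
# with the query/key/value gradients.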
79
+
80
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
81
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
82
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
83
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
84
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
85
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
86
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
87
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
88
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
89
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
90
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
91
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
92
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
93
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
94
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
95
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
96
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
97
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
98
+ _sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
99
+
100
+
101
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
102
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
103
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
104
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
105
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
106
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
107
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
108
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
109
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
110
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
111
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
112
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
113
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
114
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
115
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
116
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
117
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
118
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
119
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
120
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
121
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
122
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
123
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
124
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
125
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
126
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
127
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
128
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
129
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2)
130
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
131
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
132
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
133
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
134
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
135
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
136
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
137
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
138
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
139
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
140
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
141
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
142
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
143
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
144
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
145
+ _sfdp_pattern_1_half_training = MultiOutputPattern([view_default_5,
146
+ view_default_9,
147
+ permute_default_4,
148
+ view_default_11,
149
+ None
150
+ ])
151
+
152
+
153
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
154
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
155
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
156
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
157
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
158
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
159
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
160
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
161
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
162
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
163
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
164
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
165
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
166
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
167
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
168
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
169
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
170
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
171
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
172
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
173
+ _sfdp_pattern_1_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py ADDED
@@ -0,0 +1,204 @@
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
7
+
8
+ import torch
9
+ import torch._inductor
10
+
11
+ aten = torch.ops.aten
12
+ prims = torch.ops.prims
13
+
14
+ from torch._inductor.pattern_matcher import (
15
+ Arg,
16
+ CallFunction,
17
+ CallFunctionVarArgs,
18
+ CallMethod,
19
+ CallMethodVarArgs,
20
+ CallModule,
21
+ CallModuleVarArgs,
22
+ ExclusiveKeywordArg,
23
+ Ignored,
24
+ KeywordArg,
25
+ ListOf,
26
+ MultiOutputPattern,
27
+ PatternExpr,
28
+ RepeatedExpr,
29
+ _TargetArgsExpr,
30
+ _TargetExpr,
31
+ _TargetExprVarArgs,
32
+ )
33
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
34
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
35
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
36
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
37
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
38
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
39
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
40
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
41
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
42
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
43
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
44
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
45
+ amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
46
+ sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
47
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
48
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
49
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
50
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
51
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
52
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
53
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
54
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
55
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
56
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
57
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
58
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
59
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
60
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
61
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
62
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
63
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
64
+ view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
65
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
66
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
67
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
68
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
69
+ view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2)
70
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
71
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
72
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
73
+ div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
74
+ permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
75
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
76
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
77
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
78
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
79
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
80
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
81
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
82
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
83
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
84
+ _sfdp_pattern_10_training = MultiOutputPattern([view_default_5,
85
+ permute_default_6,
86
+ permute_default_9,
87
+ permute_default_11
88
+ ])
89
+
90
+
91
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
92
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
93
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
94
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
95
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
96
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
97
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
98
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
99
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
100
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
101
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
102
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
103
+ amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
104
+ sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
105
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
106
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
107
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
108
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
109
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
110
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
111
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
112
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
113
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
114
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
115
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
116
+ _sfdp_pattern_10_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
117
+
118
+
119
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
120
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
121
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
122
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
123
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
124
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
125
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
126
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
127
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
128
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
129
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
130
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
131
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
132
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
133
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
134
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
135
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
136
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
137
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
138
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
139
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
140
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
141
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
142
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
143
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
144
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
145
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
146
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
147
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
148
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
149
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
150
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
151
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
152
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
153
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
154
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
155
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
156
+ view_default_8 = CallFunction(aten.view.default, convert_element_type_default_3, Ignored(), _users=2)
157
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
158
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
159
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
160
+ div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
161
+ permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
162
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
163
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
164
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
165
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
166
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
167
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
168
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
169
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
170
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
171
+ _sfdp_pattern_10_half_training = MultiOutputPattern([view_default_5,
172
+ permute_default_6,
173
+ permute_default_9,
174
+ permute_default_11
175
+ ])
176
+
177
+
178
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
179
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
180
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
181
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
182
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
183
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
184
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
185
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
186
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
187
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
188
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
189
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
190
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
191
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
192
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
193
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
194
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
195
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
196
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
197
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
198
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
199
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
200
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
201
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
202
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
203
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
204
+ _sfdp_pattern_10_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py ADDED
@@ -0,0 +1,203 @@
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
7
+
8
+ import torch
9
+ import torch._inductor
10
+
11
+ aten = torch.ops.aten
12
+ prims = torch.ops.prims
13
+
14
+ from torch._inductor.pattern_matcher import (
15
+ Arg,
16
+ CallFunction,
17
+ CallFunctionVarArgs,
18
+ CallMethod,
19
+ CallMethodVarArgs,
20
+ CallModule,
21
+ CallModuleVarArgs,
22
+ ExclusiveKeywordArg,
23
+ Ignored,
24
+ KeywordArg,
25
+ ListOf,
26
+ MultiOutputPattern,
27
+ PatternExpr,
28
+ RepeatedExpr,
29
+ _TargetArgsExpr,
30
+ _TargetExpr,
31
+ _TargetExprVarArgs,
32
+ )
33
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
34
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
35
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
36
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
37
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
38
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
39
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
40
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
41
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
42
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
43
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
44
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
45
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
46
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
47
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
48
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
49
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
50
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
51
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
52
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
53
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
54
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
55
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
56
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
57
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
58
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
59
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
60
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
61
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
62
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
63
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
64
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
65
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
66
+ div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
67
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
68
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
69
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
70
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
71
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
72
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
73
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
74
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
75
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
76
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
77
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
78
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
79
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
80
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
81
+ _sfdp_pattern_11_training = MultiOutputPattern([view_default_5,
82
+ permute_default_6,
83
+ permute_default_9,
84
+ permute_default_11,
85
+ None
86
+ ])
87
+
88
+
89
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
90
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
91
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
92
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
93
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
94
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
95
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
96
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
97
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
98
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
99
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
100
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
101
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
102
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
103
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
104
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
105
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
106
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
107
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
108
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
109
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
110
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
111
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
112
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
113
+ _sfdp_pattern_11_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
114
+
115
+
116
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
117
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
118
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
119
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
120
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
121
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
122
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
123
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
124
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
125
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
126
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
127
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
128
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
129
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
130
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
131
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
132
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
133
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
134
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
135
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
136
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
137
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
138
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
139
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
140
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
141
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
142
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
143
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
144
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
145
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
146
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
147
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
148
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
149
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
150
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2)
151
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
152
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
153
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
154
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
155
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
156
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
157
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
158
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
159
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
160
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
161
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
162
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
163
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
164
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
165
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
166
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
167
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
168
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
169
+ _sfdp_pattern_11_half_training = MultiOutputPattern([view_default_5,
170
+ permute_default_6,
171
+ permute_default_9,
172
+ permute_default_11,
173
+ None
174
+ ])
175
+
176
+
177
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
178
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
179
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
180
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
181
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
182
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
183
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
184
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
185
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
186
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
187
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
188
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
189
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
190
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
191
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
192
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
193
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
194
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
195
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
196
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
197
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
198
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
199
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
200
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
201
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
202
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
203
+ _sfdp_pattern_11_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
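
The serialized files above are built entirely from the torch._inductor.pattern_matcher primitives they import (CallFunction, KeywordArg, Ignored, MultiOutputPattern). As a rough, hand-written sketch of how such a tree is shaped (not part of any generated file; 'score' is a hypothetical capture name), a toy softmax pattern can be assembled the same way:

  import torch
  from torch._inductor.pattern_matcher import CallFunction, Ignored, KeywordArg

  aten = torch.ops.aten

  # Toy pattern: softmax(score) = exp(score - amax(score)) / sum(exp(...)),
  # mirroring the amax/sub/exp/sum/div chain in every _sfdp_pattern_* file.
  score = KeywordArg('score')
  amax_default = CallFunction(aten.amax.default, score, Ignored(), True)
  sub_Tensor = CallFunction(aten.sub.Tensor, score, amax_default)
  exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
  sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
  toy_softmax_pattern = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
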
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py ADDED
@@ -0,0 +1,219 @@
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
7
+
8
+ import torch
9
+ import torch._inductor
10
+
11
+ aten = torch.ops.aten
12
+ prims = torch.ops.prims
13
+
14
+ from torch._inductor.pattern_matcher import (
15
+ Arg,
16
+ CallFunction,
17
+ CallFunctionVarArgs,
18
+ CallMethod,
19
+ CallMethodVarArgs,
20
+ CallModule,
21
+ CallModuleVarArgs,
22
+ ExclusiveKeywordArg,
23
+ Ignored,
24
+ KeywordArg,
25
+ ListOf,
26
+ MultiOutputPattern,
27
+ PatternExpr,
28
+ RepeatedExpr,
29
+ _TargetArgsExpr,
30
+ _TargetExpr,
31
+ _TargetExprVarArgs,
32
+ )
33
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
34
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
35
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
36
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
37
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
38
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
39
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
40
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
41
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
42
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
43
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
44
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
45
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
46
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
47
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
48
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
49
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
50
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
51
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
52
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
53
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
54
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
55
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
56
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
57
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
58
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
59
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
60
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
61
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
62
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
63
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
64
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
65
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
66
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
67
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
68
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
69
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
70
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
71
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
72
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
73
+ div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor'))
74
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
75
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
76
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
77
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
78
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
79
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
80
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
81
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
82
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
83
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
84
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
85
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
86
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
87
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
88
+ _sfdp_pattern_12_training = MultiOutputPattern([view_default_5,
89
+ permute_default_6,
90
+ permute_default_9,
91
+ permute_default_11,
92
+ None,
93
+ None
94
+ ])
95
+
96
+
97
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
98
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
99
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
100
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
101
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
102
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
103
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
104
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
105
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
106
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
107
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
108
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
109
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
110
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
111
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
112
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
113
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
114
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
115
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
116
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
117
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
118
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
119
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
120
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
121
+ _sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
122
+
123
+
124
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
125
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
126
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
127
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
128
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
129
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
130
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
131
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
132
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
133
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
134
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
135
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
136
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
137
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
138
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
139
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
140
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
141
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
142
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
143
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
144
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
145
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
146
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
147
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
148
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
149
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
150
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
151
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
152
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
153
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
154
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
155
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
156
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
157
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
158
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
159
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
160
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
161
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
162
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
163
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
164
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
165
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
166
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
167
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
168
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
169
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor'))
170
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
171
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
172
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
173
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
174
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
175
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
176
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
177
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
178
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
179
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
180
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
181
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
182
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
183
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
184
+ _sfdp_pattern_12_half_training = MultiOutputPattern([view_default_5,
185
+ permute_default_6,
186
+ permute_default_9,
187
+ permute_default_11,
188
+ None,
189
+ None
190
+ ])
191
+
192
+
193
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
194
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
195
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
196
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
197
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
198
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
199
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
200
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
201
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
202
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
203
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
204
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
205
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
206
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
207
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
208
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
209
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
210
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
211
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
212
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
213
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
214
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
215
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
216
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
217
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
218
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
219
+ _sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py ADDED
@@ -0,0 +1,129 @@
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
7
+
8
+ import torch
9
+ import torch._inductor
10
+
11
+ aten = torch.ops.aten
12
+ prims = torch.ops.prims
13
+
14
+ from torch._inductor.pattern_matcher import (
15
+ Arg,
16
+ CallFunction,
17
+ CallFunctionVarArgs,
18
+ CallMethod,
19
+ CallMethodVarArgs,
20
+ CallModule,
21
+ CallModuleVarArgs,
22
+ ExclusiveKeywordArg,
23
+ Ignored,
24
+ KeywordArg,
25
+ ListOf,
26
+ MultiOutputPattern,
27
+ PatternExpr,
28
+ RepeatedExpr,
29
+ _TargetArgsExpr,
30
+ _TargetExpr,
31
+ _TargetExprVarArgs,
32
+ )
33
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
34
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
35
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
36
+ bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2)
37
+ amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True)
38
+ sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default)
39
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
40
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
41
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
42
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor)
43
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2)
44
+ bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value'))
45
+ neg_default = CallFunction(aten.neg.default, div_Tensor)
46
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
47
+ bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1)
48
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
49
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
50
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2)
51
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor, _users=2)
52
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
53
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4, _users=2)
54
+ permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored())
55
+ bmm_default_3 = CallFunction(aten.bmm.default, fma_default, permute_default_2)
56
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
57
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, fma_default)
58
+ permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored())
59
+ permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored())
60
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1'))
61
+ _sfdp_pattern_13_training = MultiOutputPattern([bmm_default_1,
62
+ bmm_default_3,
63
+ permute_default_4,
64
+ bmm_default_5,
65
+ None
66
+ ])
67
+
68
+
69
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
70
+ bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2)
71
+ amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True)
72
+ sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default)
73
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
74
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
75
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
76
+ _sfdp_pattern_13_inference = CallFunction(aten.bmm.default, div_Tensor, KeywordArg('value'), _users=0)
77
+
78
+
79
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
80
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
81
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
82
+ bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default)
83
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2)
84
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
85
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
86
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
87
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
88
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
89
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
90
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
91
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2)
92
+ bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value'))
93
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
94
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
95
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
96
+ bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1)
97
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
98
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
99
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2)
100
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
101
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
102
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
103
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
104
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored(), _users=2)
105
+ permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored())
106
+ bmm_default_3 = CallFunction(aten.bmm.default, convert_element_type_default_5, permute_default_2)
107
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
108
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, convert_element_type_default_5)
109
+ permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored())
110
+ permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored())
111
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1'))
112
+ _sfdp_pattern_13_half_training = MultiOutputPattern([bmm_default_1,
113
+ bmm_default_3,
114
+ permute_default_4,
115
+ bmm_default_5,
116
+ None
117
+ ])
118
+
119
+
120
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
121
+ bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default)
122
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2)
123
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
124
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
125
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
126
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
127
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
128
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
129
+ _sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, convert_element_type_default_1, KeywordArg('value'), _users=0)
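
Each generated file pairs *_training patterns, wrapped in MultiOutputPattern so the forward output and the query/key/value gradients are matched as a group (with None for outputs the pass ignores), with single-root *_inference patterns marked _users=0. A minimal hand-written sketch of that training shape, using made-up node names rather than anything from the generated files:

  import torch
  from torch._inductor.pattern_matcher import (
      CallFunction,
      Ignored,
      KeywordArg,
      MultiOutputPattern,
  )

  aten = torch.ops.aten

  # Forward: scores = query @ key^T, out = scores @ value.
  permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
  bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default)
  bmm_default_1 = CallFunction(aten.bmm.default, bmm_default, KeywordArg('value'))
  # One illustrative backward output that reuses the transposed key.
  bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default)
  # Training patterns bundle several graph outputs; None marks ignored ones.
  toy_training_pattern = MultiOutputPattern([bmm_default_1, bmm_default_2, None])
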
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py ADDED
@@ -0,0 +1,209 @@
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
7
+
8
+ import torch
9
+ import torch._inductor
10
+
11
+ aten = torch.ops.aten
12
+ prims = torch.ops.prims
13
+
14
+ from torch._inductor.pattern_matcher import (
15
+ Arg,
16
+ CallFunction,
17
+ CallFunctionVarArgs,
18
+ CallMethod,
19
+ CallMethodVarArgs,
20
+ CallModule,
21
+ CallModuleVarArgs,
22
+ ExclusiveKeywordArg,
23
+ Ignored,
24
+ KeywordArg,
25
+ ListOf,
26
+ MultiOutputPattern,
27
+ PatternExpr,
28
+ RepeatedExpr,
29
+ _TargetArgsExpr,
30
+ _TargetExpr,
31
+ _TargetExprVarArgs,
32
+ )
33
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
34
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
35
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
36
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
37
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
38
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
39
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
40
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
41
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
42
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
43
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
44
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
45
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
46
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
47
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
48
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
49
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
50
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
51
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
52
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
53
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
54
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
55
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
56
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
57
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
58
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
59
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
60
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
61
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
62
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
63
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
64
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
65
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
66
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
67
+ div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
68
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
69
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
70
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
71
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
72
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
73
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
74
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
75
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
76
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
77
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
78
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
79
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
80
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
81
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
82
+ _sfdp_pattern_14_training = MultiOutputPattern([view_default_5,
83
+ permute_default_6,
84
+ permute_default_9,
85
+ permute_default_11,
86
+ None,
87
+ None
88
+ ])
89
+
90
+
91
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
92
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
93
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
94
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
95
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
96
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
97
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
98
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
99
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
100
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
101
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
102
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
103
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
104
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
105
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
106
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
107
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
108
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
109
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
110
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
111
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
112
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
113
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
114
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
115
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
116
+ _sfdp_pattern_14_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
117
+
118
+
119
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
120
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
121
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
122
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
123
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
124
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
125
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
126
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
127
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
128
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
129
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
130
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
131
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
132
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
133
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
134
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
135
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
136
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
137
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
138
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
139
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
140
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
141
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
142
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
143
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
144
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
145
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
146
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
147
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
148
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
149
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
150
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
151
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
152
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
153
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
154
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2)
155
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
156
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
157
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
158
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale'))
159
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
160
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
161
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
162
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
163
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
164
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
165
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
166
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
167
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
168
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
169
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
170
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
171
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
172
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
173
+ _sfdp_pattern_14_half_training = MultiOutputPattern([view_default_5,
174
+ permute_default_6,
175
+ permute_default_9,
176
+ permute_default_11,
177
+ None,
178
+ None
179
+ ])
180
+
181
+
182
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
183
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
184
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
185
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
186
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
187
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
188
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
189
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
190
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
191
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
192
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
193
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
194
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
195
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
196
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
197
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
198
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
199
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
200
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
201
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
202
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
203
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
204
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
205
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
206
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
207
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
208
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
209
+ _sfdp_pattern_14_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py ADDED
@@ -0,0 +1,227 @@
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
7
+
8
+ import torch
9
+ import torch._inductor
10
+
11
+ aten = torch.ops.aten
12
+ prims = torch.ops.prims
13
+
14
+ from torch._inductor.pattern_matcher import (
15
+ Arg,
16
+ CallFunction,
17
+ CallFunctionVarArgs,
18
+ CallMethod,
19
+ CallMethodVarArgs,
20
+ CallModule,
21
+ CallModuleVarArgs,
22
+ ExclusiveKeywordArg,
23
+ Ignored,
24
+ KeywordArg,
25
+ ListOf,
26
+ MultiOutputPattern,
27
+ PatternExpr,
28
+ RepeatedExpr,
29
+ _TargetArgsExpr,
30
+ _TargetExpr,
31
+ _TargetExprVarArgs,
32
+ )
33
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
34
+ expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
35
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
36
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
37
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
38
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
39
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
40
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
41
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
42
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
43
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
44
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
45
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
46
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
47
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
48
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
49
+ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
50
+ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
51
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
52
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
53
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
54
+ expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
55
+ view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
56
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
57
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
58
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
59
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
60
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
61
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
62
+ scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
63
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
64
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
65
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
66
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
67
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
68
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
69
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
70
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
71
+ where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default)
72
+ div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
73
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
74
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
75
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
76
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
77
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
78
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
79
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
80
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
81
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
82
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
83
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
84
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
85
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
86
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
87
+ _sfdp_pattern_15_training = MultiOutputPattern([view_default_5,
88
+ permute_default_6,
89
+ permute_default_9,
90
+ permute_default_11,
91
+ None,
92
+ None
93
+ ])
94
+
95
+
96
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
97
+ view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
98
+ expand_default = CallFunction(aten.expand.default, view_default, Ignored())
99
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
100
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
101
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
102
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
103
+ view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
104
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
105
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
106
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
107
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
108
+ view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
109
+ bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
110
+ view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
111
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
112
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
113
+ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
114
+ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
115
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
116
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
117
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
118
+ expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
119
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
120
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
121
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
122
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
123
+ view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
124
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
125
+ _sfdp_pattern_15_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
126
+
127
+
128
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
129
+ expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
130
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
131
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
132
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
133
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
134
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
135
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
136
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
137
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
138
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
139
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
140
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
141
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
142
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
143
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
144
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
145
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
146
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
147
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
148
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
149
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
150
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
151
+ expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
152
+ view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
153
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
154
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
155
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
156
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
157
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
158
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
159
+ scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
160
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
161
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
162
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
163
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
164
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
165
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
166
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
167
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2)
168
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
169
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
170
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
171
+ where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_4)
172
+ div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
173
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
174
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
175
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
176
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
177
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
178
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
179
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
180
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
181
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
182
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
183
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
184
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
185
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
186
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
187
+ _sfdp_pattern_15_half_training = MultiOutputPattern([view_default_5,
188
+ permute_default_6,
189
+ permute_default_9,
190
+ permute_default_11,
191
+ None,
192
+ None
193
+ ])
194
+
195
+
196
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
197
+ view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
198
+ expand_default = CallFunction(aten.expand.default, view_default, Ignored())
199
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
200
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
201
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
202
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
203
+ view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
204
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
205
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
206
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
207
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
208
+ view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
209
+ bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
210
+ view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
211
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
212
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
213
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
214
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
215
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
216
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
217
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
218
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
219
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
220
+ expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
221
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
222
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
223
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
224
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
225
+ view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
226
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
227
+ _sfdp_pattern_15_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py ADDED
@@ -0,0 +1,598 @@
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
7
+
8
+ import torch
9
+ import torch._inductor
10
+
11
+ aten = torch.ops.aten
12
+ prims = torch.ops.prims
13
+
14
+ from torch._inductor.pattern_matcher import (
15
+ Arg,
16
+ CallFunction,
17
+ CallFunctionVarArgs,
18
+ CallMethod,
19
+ CallMethodVarArgs,
20
+ CallModule,
21
+ CallModuleVarArgs,
22
+ ExclusiveKeywordArg,
23
+ Ignored,
24
+ KeywordArg,
25
+ ListOf,
26
+ MultiOutputPattern,
27
+ PatternExpr,
28
+ RepeatedExpr,
29
+ _TargetArgsExpr,
30
+ _TargetExpr,
31
+ _TargetExprVarArgs,
32
+ )
33
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
34
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
35
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
36
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
37
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
38
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
39
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
40
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
41
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
42
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
43
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
44
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
45
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
46
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
47
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
48
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
49
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
50
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
51
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
52
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
53
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
54
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
55
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
56
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
57
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
58
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
59
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
60
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
61
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
62
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
63
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
64
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
65
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
66
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
67
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
68
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
69
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
70
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
71
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
72
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
73
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
74
+ div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
75
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
76
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
77
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
78
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
79
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
80
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
81
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
82
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
83
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
84
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
85
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
86
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
87
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
88
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
89
+ _sfdp_pattern_16_training = MultiOutputPattern([view_default_5,
90
+ permute_default_6,
91
+ permute_default_9,
92
+ permute_default_11,
93
+ None,
94
+ None,
95
+ None
96
+ ])
97
+
98
+
99
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
100
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
101
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
102
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
103
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
104
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
105
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
106
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
107
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
108
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
109
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
110
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
111
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
112
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
113
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
114
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
115
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
116
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
117
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
118
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
119
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
120
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
121
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
122
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
123
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
124
+ _sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
125
+
126
+
127
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
128
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
129
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
130
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
131
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
132
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
133
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
134
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
135
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
136
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
137
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
138
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
139
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
140
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
141
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
142
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
143
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
144
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
145
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
146
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
147
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
148
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
149
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
150
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
151
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
152
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
153
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
154
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
155
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
156
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
157
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
158
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
159
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
160
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
161
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
162
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
163
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
164
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
165
+ div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale'))
166
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
167
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
168
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
169
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
170
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
171
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
172
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
173
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
174
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
175
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
176
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
177
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
178
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
179
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
180
+ _sfdp_pattern_16_bs1_training = MultiOutputPattern([view_default_5,
181
+ permute_default_6,
182
+ permute_default_9,
183
+ permute_default_11,
184
+ None,
185
+ None,
186
+ None
187
+ ])
188
+
189
+
190
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
191
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
192
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
193
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
194
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
195
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
196
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
197
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
198
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
199
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
200
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
201
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
202
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
203
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
204
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
205
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
206
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
207
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
208
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
209
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
210
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
211
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
212
+ _sfdp_pattern_16_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
213
+
214
+
215
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
216
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
217
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
218
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
219
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
220
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
221
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
222
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
223
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
224
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
225
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
226
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
227
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
228
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
229
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
230
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
231
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
232
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
233
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
234
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
235
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
236
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
237
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
238
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
239
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
240
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
241
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
242
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
243
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
244
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
245
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
246
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
247
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
248
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
249
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
250
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
251
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
252
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
253
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
254
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
255
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
256
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
257
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
258
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
259
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
260
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
261
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale'))
262
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
263
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
264
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
265
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
266
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
267
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
268
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
269
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
270
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
271
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
272
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
273
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
274
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
275
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
276
+ _sfdp_pattern_16_half_training = MultiOutputPattern([view_default_5,
277
+ permute_default_6,
278
+ permute_default_9,
279
+ permute_default_11,
280
+ None,
281
+ None,
282
+ None
283
+ ])
284
+
285
+
286
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
287
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
288
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
289
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
290
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
291
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
292
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
293
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
294
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
295
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
296
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
297
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
298
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
299
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
300
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
301
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
302
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
303
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
304
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
305
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
306
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
307
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
308
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
309
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
310
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
311
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
312
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
313
+ _sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
314
+
315
+
316
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
317
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
318
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
319
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
320
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
321
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
322
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
323
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
324
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
325
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
326
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
327
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
328
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
329
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
330
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
331
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
332
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
333
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
334
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
335
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
336
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
337
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
338
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
339
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
340
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
341
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
342
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
343
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
344
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
345
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
346
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
347
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
348
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
349
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
350
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
351
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
352
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
353
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
354
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
355
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
356
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
357
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
358
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
359
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale'))
360
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
361
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
362
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
363
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
364
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
365
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
366
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
367
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
368
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
369
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
370
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
371
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
372
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
373
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
374
+ _sfdp_pattern_16_half_bs1_training = MultiOutputPattern([view_default_5,
375
+ permute_default_6,
376
+ permute_default_9,
377
+ permute_default_11,
378
+ None,
379
+ None,
380
+ None
381
+ ])
382
+
383
+
384
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
385
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
386
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
387
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
388
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
389
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
390
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
391
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
392
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
393
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
394
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
395
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
396
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
397
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
398
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
399
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
400
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
401
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
402
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
403
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
404
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
405
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
406
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
407
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
408
+ _sfdp_pattern_16_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
409
+
410
+
411
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
412
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
413
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
414
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
415
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
416
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
417
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
418
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
419
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
420
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
421
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
422
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
423
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
424
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
425
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
426
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
427
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
428
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
429
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
430
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
431
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
432
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
433
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
434
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
435
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
436
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
437
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
438
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
439
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
440
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
441
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
442
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
443
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
444
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
445
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
446
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
447
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
448
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
449
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
450
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2)
451
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
452
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
453
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale'))
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+ _sfdp_pattern_16_half_mask_fp32_training = MultiOutputPattern([view_default_5,
+ permute_default_6,
+ permute_default_9,
+ permute_default_11,
+ None,
+ None,
+ None
+ ])
+
+
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ _sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
+
+
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored())
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2)
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale'))
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+ _sfdp_pattern_16_half_mask_fp32_bs1_training = MultiOutputPattern([view_default_5,
+ permute_default_6,
+ permute_default_9,
+ permute_default_11,
+ None,
+ None,
+ None
+ ])
+
+
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ _sfdp_pattern_16_half_mask_fp32_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
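All of the serialized patterns in these files follow the same construction: each assignment composes a PatternExpr tree of CallFunction nodes, where KeywordArg(...) captures a named input (query, key, value, attn_mask, inv_scale, dropout_p, tangents_1), Ignored() matches any argument, and _users records how many consumers the matched node must have. A minimal sketch of that construction, assuming only the pattern_matcher names already imported in these files (the toy_sdpa_pattern below and its stripped-down shape are illustrative, not part of the generated code):

import torch
from torch._inductor.pattern_matcher import CallFunction, Ignored, KeywordArg

aten = torch.ops.aten

# Toy pattern: softmax((Q @ K^T) / inv_scale) @ V, written the same way the
# generated _sfdp_pattern_* expressions are, but without the view/expand/clone
# plumbing that the real traced graphs carry.
k_t = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
scores = CallFunction(aten.div.Tensor,
                      CallFunction(aten.bmm.default, KeywordArg('query'), k_t),
                      KeywordArg('inv_scale'), _users=2)  # consumed by amax and sub
amax = CallFunction(aten.amax.default, scores, Ignored(), True)
exp = CallFunction(aten.exp.default, CallFunction(aten.sub.Tensor, scores, amax), _users=2)
total = CallFunction(aten.sum.dim_IntList, exp, Ignored(), True)
probs = CallFunction(aten.div.Tensor, exp, total)
toy_sdpa_pattern = CallFunction(aten.bmm.default, probs, KeywordArg('value'))

The _training variants additionally take tangents_1 and wrap their results in a MultiOutputPattern that bundles the forward output (view_default_5) with the permute nodes computed from tangents_1, as seen above.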
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py ADDED
@@ -0,0 +1,243 @@
+ # mypy: ignore-errors
+
+ # noqa: F401, E501
+ # This is an auto-generated file. Please do not modify it by hand.
+ # To re-generate, run:
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+ import torch
+ import torch._inductor
+
+ aten = torch.ops.aten
+ prims = torch.ops.prims
+
+ from torch._inductor.pattern_matcher import (
+ Arg,
+ CallFunction,
+ CallFunctionVarArgs,
+ CallMethod,
+ CallMethodVarArgs,
+ CallModule,
+ CallModuleVarArgs,
+ ExclusiveKeywordArg,
+ Ignored,
+ KeywordArg,
+ ListOf,
+ MultiOutputPattern,
+ PatternExpr,
+ RepeatedExpr,
+ _TargetArgsExpr,
+ _TargetExpr,
+ _TargetExprVarArgs,
+ )
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
+ expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
+ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+ expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
+ where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default)
+ div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+ _sfdp_pattern_17_training = MultiOutputPattern([view_default_5,
+ permute_default_6,
+ permute_default_9,
+ permute_default_11,
+ None,
+ None,
+ None
+ ])
+
+
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
+ view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
+ expand_default = CallFunction(aten.expand.default, view_default, Ignored())
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+ view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
+ view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
+ bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
+ view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2)
+ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
+ view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
+ _sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
+
+
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
+ expand_default = CallFunction(aten.expand.default, eq_Scalar, Ignored(), _users=2)
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'))
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+ expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
+ where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5)
+ div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale'))
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+ _sfdp_pattern_17_half_training = MultiOutputPattern([view_default_5,
+ permute_default_6,
+ permute_default_9,
+ permute_default_11,
+ None,
+ None,
+ None
+ ])
+
+
+ eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored())
+ view_default = CallFunction(aten.view.default, eq_Scalar, Ignored())
+ expand_default = CallFunction(aten.expand.default, view_default, Ignored())
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+ view_default_1 = CallFunction(aten.view.default, clone_default, Ignored())
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format)
+ view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored())
+ bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2)
+ view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored())
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale'))
+ where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor)
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+ expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
+ expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format)
+ view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5)
+ _sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py ADDED
@@ -0,0 +1,452 @@
+ # mypy: ignore-errors
+
+ # noqa: F401, E501
+ # This is an auto-generated file. Please do not modify it by hand.
+ # To re-generate, run:
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+ import torch
+ import torch._inductor
+
+ aten = torch.ops.aten
+ prims = torch.ops.prims
+
+ from torch._inductor.pattern_matcher import (
+ Arg,
+ CallFunction,
+ CallFunctionVarArgs,
+ CallMethod,
+ CallMethodVarArgs,
+ CallModule,
+ CallModuleVarArgs,
+ ExclusiveKeywordArg,
+ Ignored,
+ KeywordArg,
+ ListOf,
+ MultiOutputPattern,
+ PatternExpr,
+ RepeatedExpr,
+ _TargetArgsExpr,
+ _TargetExpr,
+ _TargetExprVarArgs,
+ )
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2)
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
+ full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2)
+ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
+ scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
+ where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default)
+ div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default)
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+ _sfdp_pattern_18_training = MultiOutputPattern([view_default_5,
+ permute_default_1,
+ permute_default_3,
+ permute_default_6,
+ permute_default_9,
+ permute_default_11,
+ None,
+ None
+ ])
+
+
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
+ full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2)
+ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ _sfdp_pattern_18_inference = MultiOutputPattern([view_default_5,
+ permute_default_1,
+ permute_default_3
+ ])
+
+
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2)
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
+ full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2)
+ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
+ scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
+ where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default)
+ div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default)
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+ _sfdp_pattern_18_bs1_training = MultiOutputPattern([view_default_5,
+ permute_default_1,
+ permute_default_3,
+ permute_default_6,
+ permute_default_9,
+ permute_default_11,
+ None,
+ None
+ ])
+
+
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
+ full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2)
+ amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ _sfdp_pattern_18_bs1_inference = MultiOutputPattern([view_default_5,
+ permute_default_1,
+ permute_default_3
+ ])
+
+
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2)
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
+ full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
+ scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
+ where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default)
+ div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default)
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+ _sfdp_pattern_18_half_training = MultiOutputPattern([view_default_5,
+ permute_default_1,
+ permute_default_3,
+ permute_default_6,
+ permute_default_9,
+ permute_default_11,
+ None,
+ None
+ ])
+
+
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
+ full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ _sfdp_pattern_18_half_inference = MultiOutputPattern([view_default_5,
+ permute_default_1,
+ permute_default_3
+ ])
+
+
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2)
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
+ full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
+ scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
+ where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default)
+ div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default)
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
+ _sfdp_pattern_18_half_bs1_training = MultiOutputPattern([view_default_5,
412
+ permute_default_1,
413
+ permute_default_3,
414
+ permute_default_6,
415
+ permute_default_9,
416
+ permute_default_11,
417
+ None,
418
+ None
419
+ ])
420
+
421
+
422
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
423
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
424
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
425
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2)
426
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
427
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
428
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
429
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
430
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
431
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
432
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
433
+ full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
434
+ where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
435
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2)
436
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
437
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
438
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
439
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
440
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
441
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
442
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
443
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
444
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2)
445
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
446
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
447
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
448
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
449
+ _sfdp_pattern_18_half_bs1_inference = MultiOutputPattern([view_default_5,
450
+ permute_default_1,
451
+ permute_default_3
452
+ ])
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py ADDED
@@ -0,0 +1,208 @@
+ # mypy: ignore-errors
+
+ # noqa: F401, E501
+ # This is an auto-generated file. Please do not modify it by hand.
+ # To re-generate, run:
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+ import torch
+ import torch._inductor
+
+ aten = torch.ops.aten
+ prims = torch.ops.prims
+
+ from torch._inductor.pattern_matcher import (
+ Arg,
+ CallFunction,
+ CallFunctionVarArgs,
+ CallMethod,
+ CallMethodVarArgs,
+ CallModule,
+ CallModuleVarArgs,
+ ExclusiveKeywordArg,
+ Ignored,
+ KeywordArg,
+ ListOf,
+ MultiOutputPattern,
+ PatternExpr,
+ RepeatedExpr,
+ _TargetArgsExpr,
+ _TargetExpr,
+ _TargetExprVarArgs,
+ )
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2)
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
+ full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
+ add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2)
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
+ scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
+ where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default)
+ div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default)
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ _sfdp_pattern_19_training = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None,
+ None,
+ None
+ ])
+
+
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
+ full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
+ add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2)
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ _sfdp_pattern_19_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
+
+
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2)
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
+ full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
+ add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2)
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default)
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, Ignored())
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
+ scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored())
+ where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_3, scalar_tensor_default)
+ div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default)
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ _sfdp_pattern_19_half_training = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None,
+ None,
+ None
+ ])
+
+
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default)
+ full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1)
+ add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2)
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ _sfdp_pattern_19_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py ADDED
@@ -0,0 +1,173 @@
+ # mypy: ignore-errors
+
+ # noqa: F401, E501
+ # This is an auto-generated file. Please do not modify it by hand.
+ # To re-generate, run:
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+ import torch
+ import torch._inductor
+
+ aten = torch.ops.aten
+ prims = torch.ops.prims
+
+ from torch._inductor.pattern_matcher import (
+ Arg,
+ CallFunction,
+ CallFunctionVarArgs,
+ CallMethod,
+ CallMethodVarArgs,
+ CallModule,
+ CallModuleVarArgs,
+ ExclusiveKeywordArg,
+ Ignored,
+ KeywordArg,
+ ListOf,
+ MultiOutputPattern,
+ PatternExpr,
+ RepeatedExpr,
+ _TargetArgsExpr,
+ _TargetExpr,
+ _TargetExprVarArgs,
+ )
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
+ amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ neg_default = CallFunction(aten.neg.default, div_Tensor)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1)
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor'))
+ view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ _sfdp_pattern_2_training = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None
+ ])
+
+
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
+ amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ _sfdp_pattern_2_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
+
+
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1)
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, KeywordArg('scale_factor'))
+ view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ _sfdp_pattern_2_half_training = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None
+ ])
+
+
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ _sfdp_pattern_2_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py ADDED
@@ -0,0 +1,189 @@
+ # mypy: ignore-errors
+
+ # noqa: F401, E501
+ # This is an auto-generated file. Please do not modify it by hand.
+ # To re-generate, run:
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+ import torch
+ import torch._inductor
+
+ aten = torch.ops.aten
+ prims = torch.ops.prims
+
+ from torch._inductor.pattern_matcher import (
+ Arg,
+ CallFunction,
+ CallFunctionVarArgs,
+ CallMethod,
+ CallMethodVarArgs,
+ CallModule,
+ CallModuleVarArgs,
+ ExclusiveKeywordArg,
+ Ignored,
+ KeywordArg,
+ ListOf,
+ MultiOutputPattern,
+ PatternExpr,
+ RepeatedExpr,
+ _TargetArgsExpr,
+ _TargetExpr,
+ _TargetExprVarArgs,
+ )
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
+ div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor'))
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ _sfdp_pattern_3_training = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None,
+ None
+ ])
+
+
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2)
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ _sfdp_pattern_3_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
+
+
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor'))
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ _sfdp_pattern_3_half_training = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None,
+ None
+ ])
+
+
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'))
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ _sfdp_pattern_3_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py ADDED
@@ -0,0 +1,189 @@
+ # mypy: ignore-errors
+
+ # noqa: F401, E501
+ # This is an auto-generated file. Please do not modify it by hand.
+ # To re-generate, run:
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+ import torch
+ import torch._inductor
+
+ aten = torch.ops.aten
+ prims = torch.ops.prims
+
+ from torch._inductor.pattern_matcher import (
+ Arg,
+ CallFunction,
+ CallFunctionVarArgs,
+ CallMethod,
+ CallMethodVarArgs,
+ CallModule,
+ CallModuleVarArgs,
+ ExclusiveKeywordArg,
+ Ignored,
+ KeywordArg,
+ ListOf,
+ MultiOutputPattern,
+ PatternExpr,
+ RepeatedExpr,
+ _TargetArgsExpr,
+ _TargetExpr,
+ _TargetExprVarArgs,
+ )
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
+ amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor)
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ neg_default = CallFunction(aten.neg.default, div_Tensor)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3)
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor_4, div_Tensor, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5)
+ mul_Tensor_6 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor'))
+ view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ _sfdp_pattern_4_training = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None,
+ None
+ ])
+
+
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2)
+ amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ _sfdp_pattern_4_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
+
+
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored())
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored())
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3)
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_4, Ignored())
+ mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True)
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5)
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
+ mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor'))
+ view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2)
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
+ _sfdp_pattern_4_half_training = MultiOutputPattern([view_default_5,
+ view_default_9,
+ permute_default_4,
+ view_default_11,
+ None,
+ None
+ ])
+
+
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
172
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
173
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
174
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
175
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
176
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'))
177
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2)
178
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
179
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
180
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
181
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
182
+ div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
183
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored())
184
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
185
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
186
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
187
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
188
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
189
+ _sfdp_pattern_4_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
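
For orientation (this note and sketch are not part of the auto-generated file): the `_sfdp_pattern_4_*` graphs above appear to serialize a decomposed scaled-dot-product attention in which the scores are scaled by multiplication with `scale_factor`, with a dropout branch (`dropout_p`, the rand/gt_Scalar nodes) in the training variants and dtype casts around the softmax in the `_half_` variants. A minimal eager-mode sketch of the computation such a graph traces; the function name and the assumption of standard batched (..., seq, head_dim) inputs are illustrative only:

import torch
import torch.nn.functional as F

def sfdp_pattern_4_sketch(query, key, value, scale_factor, dropout_p=0.0):
    # mul_Tensor node: scores scaled by multiplication rather than division
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale_factor
    # amax/sub/exp/sum/div chain: numerically stable softmax
    attn = F.softmax(scores, dim=-1)
    # rand/gt_Scalar/mul chain (training variants only): inverted dropout
    attn = F.dropout(attn, p=dropout_p)
    return torch.matmul(attn, value)
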
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py ADDED
@@ -0,0 +1,177 @@
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
7
+
8
+ import torch
9
+ import torch._inductor
10
+
11
+ aten = torch.ops.aten
12
+ prims = torch.ops.prims
13
+
14
+ from torch._inductor.pattern_matcher import (
15
+ Arg,
16
+ CallFunction,
17
+ CallFunctionVarArgs,
18
+ CallMethod,
19
+ CallMethodVarArgs,
20
+ CallModule,
21
+ CallModuleVarArgs,
22
+ ExclusiveKeywordArg,
23
+ Ignored,
24
+ KeywordArg,
25
+ ListOf,
26
+ MultiOutputPattern,
27
+ PatternExpr,
28
+ RepeatedExpr,
29
+ _TargetArgsExpr,
30
+ _TargetExpr,
31
+ _TargetExprVarArgs,
32
+ )
33
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
34
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
35
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
36
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
37
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
38
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
39
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
40
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
41
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
42
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
43
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
44
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
45
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
46
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
47
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
48
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
49
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
50
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
51
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
52
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
53
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
54
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
55
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
56
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
57
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
58
+ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
59
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
60
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
61
+ div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
62
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
63
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
64
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
65
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
66
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
67
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
68
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
69
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
70
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
71
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
72
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
73
+ _sfdp_pattern_5_training = MultiOutputPattern([view_default_5,
74
+ view_default_9,
75
+ permute_default_4,
76
+ view_default_11,
77
+ None
78
+ ])
79
+
80
+
81
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
82
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
83
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
84
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
85
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
86
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
87
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
88
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
89
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
90
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
91
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
92
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
93
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
94
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
95
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
96
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
97
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
98
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
99
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
100
+ _sfdp_pattern_5_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
101
+
102
+
103
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
104
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
105
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
106
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
107
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
108
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
109
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
110
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
111
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
112
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
113
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
114
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
115
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
116
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
117
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
118
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
119
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
120
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
121
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
122
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
123
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
124
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
125
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
126
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
127
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
128
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
129
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
130
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
131
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
132
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2)
133
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
134
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
135
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
136
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored())
137
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
138
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
139
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
140
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
141
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
142
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
143
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
144
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
145
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
146
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
147
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
148
+ _sfdp_pattern_5_half_training = MultiOutputPattern([view_default_5,
149
+ view_default_9,
150
+ permute_default_4,
151
+ view_default_11,
152
+ None
153
+ ])
154
+
155
+
156
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
157
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
158
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
159
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
160
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
161
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
162
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
163
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
164
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
165
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
166
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
167
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
168
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
169
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
170
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
171
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
172
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
173
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
174
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
175
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
176
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
177
+ _sfdp_pattern_5_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
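
For orientation (not part of the generated file): `_sfdp_pattern_5_*` differs from pattern 4 in that the scores are scaled by division (the divisor is `Ignored()` in the pattern, commonly sqrt(head_dim)) and an additive `attn_mask` is applied before the softmax, with no dropout. A rough sketch under those assumptions; `inv_scale` and the function name are placeholders:

import torch

def sfdp_pattern_5_sketch(query, key, value, attn_mask, inv_scale):
    # div_Tensor node: divide scores by a constant; add_Tensor node: additive mask
    scores = torch.matmul(query, key.transpose(-2, -1)) / inv_scale
    attn = torch.softmax(scores + attn_mask, dim=-1)
    return torch.matmul(attn, value)
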
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py ADDED
@@ -0,0 +1,193 @@
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
7
+
8
+ import torch
9
+ import torch._inductor
10
+
11
+ aten = torch.ops.aten
12
+ prims = torch.ops.prims
13
+
14
+ from torch._inductor.pattern_matcher import (
15
+ Arg,
16
+ CallFunction,
17
+ CallFunctionVarArgs,
18
+ CallMethod,
19
+ CallMethodVarArgs,
20
+ CallModule,
21
+ CallModuleVarArgs,
22
+ ExclusiveKeywordArg,
23
+ Ignored,
24
+ KeywordArg,
25
+ ListOf,
26
+ MultiOutputPattern,
27
+ PatternExpr,
28
+ RepeatedExpr,
29
+ _TargetArgsExpr,
30
+ _TargetExpr,
31
+ _TargetExprVarArgs,
32
+ )
33
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
34
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
35
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
36
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
37
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
38
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
39
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
40
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
41
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
42
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
43
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
44
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
45
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
46
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
47
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
48
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
49
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
50
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
51
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
52
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
53
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
54
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
55
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
56
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
57
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
58
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
59
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
60
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
61
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
62
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
63
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored())
64
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
65
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
66
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
67
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
68
+ div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
69
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
70
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
71
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
72
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
73
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
74
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
75
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
76
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
77
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
78
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
79
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
80
+ _sfdp_pattern_6_training = MultiOutputPattern([view_default_5,
81
+ view_default_9,
82
+ permute_default_4,
83
+ view_default_11,
84
+ None,
85
+ None
86
+ ])
87
+
88
+
89
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
90
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
91
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
92
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
93
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
94
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
95
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
96
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
97
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2)
98
+ amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True)
99
+ sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default)
100
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
101
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
102
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
103
+ expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
104
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
105
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
106
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
107
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
108
+ _sfdp_pattern_6_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
109
+
110
+
111
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
112
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
113
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
114
+ view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
115
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
116
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
117
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
118
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
119
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
120
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
121
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
122
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
123
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
124
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
125
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
126
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
127
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
128
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2)
129
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1)
130
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
131
+ expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored())
132
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
133
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
134
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
135
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
136
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
137
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2)
138
+ neg_default = CallFunction(aten.neg.default, convert_element_type_default_2)
139
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
140
+ permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
141
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
142
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
143
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
144
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
145
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2)
146
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored())
147
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2)
148
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
149
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
150
+ convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
151
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored())
152
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
153
+ permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
154
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
155
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
156
+ permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
157
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
158
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
159
+ permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
160
+ permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
161
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
162
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
163
+ _sfdp_pattern_6_half_training = MultiOutputPattern([view_default_5,
164
+ view_default_9,
165
+ permute_default_4,
166
+ view_default_11,
167
+ None,
168
+ None
169
+ ])
170
+
171
+
172
+ expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
173
+ view_default = CallFunction(aten.view.default, expand_default, Ignored())
174
+ permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
175
+ expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
176
+ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
177
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
178
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
179
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
180
+ add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'))
181
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2)
182
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
183
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
184
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
185
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
186
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
187
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
188
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
189
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
190
+ expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
191
+ view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
192
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
193
+ _sfdp_pattern_6_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
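
For orientation (not part of the generated file): `_sfdp_pattern_6_*` is the same masked, division-scaled attention as pattern 5 with a dropout branch added in the training variants. The rand/gt_Scalar/mul chain is inverted dropout written out node by node; a sketch of just that piece, assuming the `Ignored()` rescale constant is the usual 1/(1-p) and using a placeholder function name:

import torch

def dropout_as_matched(attn, dropout_p):
    # gt_Scalar node: keep-mask drawn from aten.rand.default
    keep = torch.rand_like(attn) > dropout_p
    # the two mul_Tensor nodes: mask the probabilities, then rescale (requires dropout_p < 1)
    return attn * keep * (1.0 / (1.0 - dropout_p))
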
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py ADDED
@@ -0,0 +1,220 @@
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
7
+
8
+ import torch
9
+ import torch._inductor
10
+
11
+ aten = torch.ops.aten
12
+ prims = torch.ops.prims
13
+
14
+ from torch._inductor.pattern_matcher import (
15
+ Arg,
16
+ CallFunction,
17
+ CallFunctionVarArgs,
18
+ CallMethod,
19
+ CallMethodVarArgs,
20
+ CallModule,
21
+ CallModuleVarArgs,
22
+ ExclusiveKeywordArg,
23
+ Ignored,
24
+ KeywordArg,
25
+ ListOf,
26
+ MultiOutputPattern,
27
+ PatternExpr,
28
+ RepeatedExpr,
29
+ _TargetArgsExpr,
30
+ _TargetExpr,
31
+ _TargetExprVarArgs,
32
+ )
33
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
34
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
35
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
36
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
37
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
38
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
39
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
40
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
41
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
42
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
43
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
44
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
45
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
46
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
47
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
48
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
49
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
50
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
51
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
52
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
53
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
54
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
55
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
56
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
57
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
58
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
59
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
60
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
61
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
62
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
63
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
64
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
65
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
66
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
67
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
68
+ view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
69
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
70
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
71
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
72
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
73
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
74
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
75
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
76
+ div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
77
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
78
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
79
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
80
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
81
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
82
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
83
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
84
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
85
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
86
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
87
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
88
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
89
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
90
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
91
+ _sfdp_pattern_7_training = MultiOutputPattern([view_default_5,
92
+ permute_default_6,
93
+ permute_default_9,
94
+ permute_default_11,
95
+ None
96
+ ])
97
+
98
+
99
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
100
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
101
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
102
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
103
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
104
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
105
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
106
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
107
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
108
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
109
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
110
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
111
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
112
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
113
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
114
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
115
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
116
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
117
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
118
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
119
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
120
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
121
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
122
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
123
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
124
+ _sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
125
+
126
+
127
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
128
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
129
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
130
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
131
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
132
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
133
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
134
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
135
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
136
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
137
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
138
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
139
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
140
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
141
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
142
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
143
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
144
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
145
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
146
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
147
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
148
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
149
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
150
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
151
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
152
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
153
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
154
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
155
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
156
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
157
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
158
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
159
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
160
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
161
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
162
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
163
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
164
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
165
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
166
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
167
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
168
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
169
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
170
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
171
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored())
172
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
173
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
174
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
175
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
176
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
177
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
178
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
179
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
180
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
181
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
182
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
183
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
184
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
185
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
186
+ _sfdp_pattern_7_half_training = MultiOutputPattern([view_default_5,
187
+ permute_default_6,
188
+ permute_default_9,
189
+ permute_default_11,
190
+ None
191
+ ])
192
+
193
+
194
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
195
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
196
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
197
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
198
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
199
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
200
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
201
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
202
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
203
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
204
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
205
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
206
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
207
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
208
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
209
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
210
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
211
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
212
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
213
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
214
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
215
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
216
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
217
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
218
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
219
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
220
+ _sfdp_pattern_7_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
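
For orientation (not part of the generated file): in `_sfdp_pattern_7_*` the query/key/value inputs arrive in a permuted layout, so the graph begins with permute + clone(memory_format=torch.contiguous_format) before the bmm; the scores are scaled by division, the training variants include dropout, and the `_half_` variants cast to a wider dtype around the softmax. A rough sketch; the (0, 2, 1, 3) permutation, `inv_scale`, and the function name are assumptions, since the corresponding arguments are `Ignored()` in the pattern:

import torch
import torch.nn.functional as F

def sfdp_pattern_7_sketch(query, key, value, inv_scale, dropout_p=0.0):
    # permute/clone prologue: reorder e.g. (B, S, H, D) -> (B, H, S, D) and make contiguous
    q = query.permute(0, 2, 1, 3).contiguous()
    k = key.permute(0, 2, 1, 3).contiguous()
    v = value.permute(0, 2, 1, 3).contiguous()
    scores = torch.matmul(q, k.transpose(-2, -1)) / inv_scale
    attn = F.softmax(scores, dim=-1)
    attn = F.dropout(attn, p=dropout_p)  # training variants only
    return torch.matmul(attn, v)
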
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py ADDED
@@ -0,0 +1,204 @@
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
7
+
8
+ import torch
9
+ import torch._inductor
10
+
11
+ aten = torch.ops.aten
12
+ prims = torch.ops.prims
13
+
14
+ from torch._inductor.pattern_matcher import (
15
+ Arg,
16
+ CallFunction,
17
+ CallFunctionVarArgs,
18
+ CallMethod,
19
+ CallMethodVarArgs,
20
+ CallModule,
21
+ CallModuleVarArgs,
22
+ ExclusiveKeywordArg,
23
+ Ignored,
24
+ KeywordArg,
25
+ ListOf,
26
+ MultiOutputPattern,
27
+ PatternExpr,
28
+ RepeatedExpr,
29
+ _TargetArgsExpr,
30
+ _TargetExpr,
31
+ _TargetExprVarArgs,
32
+ )
33
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
34
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
35
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
36
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
37
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
38
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
39
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
40
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
41
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
42
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
43
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
44
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
45
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
46
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
47
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
48
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
49
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
50
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
51
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
52
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
53
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
54
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
55
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
56
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
57
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
58
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
59
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
60
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
61
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
62
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
63
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
64
+ view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
65
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
66
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
67
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
68
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
69
+ div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored())
70
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
71
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
72
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
73
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
74
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
75
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
76
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
77
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
78
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
79
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
80
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
81
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
82
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
83
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
84
+ _sfdp_pattern_8_training = MultiOutputPattern([view_default_5,
85
+ permute_default_6,
86
+ permute_default_9,
87
+ permute_default_11
88
+ ])
89
+
90
+
91
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
92
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
93
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
94
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
95
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
96
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
97
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
98
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
99
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
100
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
101
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
102
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2)
103
+ amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
104
+ sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
105
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
106
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
107
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
108
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
109
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
110
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
111
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
112
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
113
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
114
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
115
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
116
+ _sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
117
+
118
+
119
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
120
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
121
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
122
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
123
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
124
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
125
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
126
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
127
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
128
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
129
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
130
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
131
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
132
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
133
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
134
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
135
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
136
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
137
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
138
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
139
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
140
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
141
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
142
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
143
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
144
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
145
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
146
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
147
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
148
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
149
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
150
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
151
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
152
+ mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2)
153
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
154
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor)
155
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
156
+ div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, Ignored())
157
+ view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
158
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
159
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
160
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
161
+ permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored())
162
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
163
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
164
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
165
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
166
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
167
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
168
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
169
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
170
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
171
+ _sfdp_pattern_8_half_training = MultiOutputPattern([view_default_5,
172
+ permute_default_6,
173
+ permute_default_9,
174
+ permute_default_11
175
+ ])
176
+
177
+
178
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
179
+ expand_default = CallFunction(aten.expand.default, permute_default, Ignored())
180
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
181
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
182
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
183
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
184
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
185
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
186
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
187
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
188
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
189
+ div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored())
190
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2)
191
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
192
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
193
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
194
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
195
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
196
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
197
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
198
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
199
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
200
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
201
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
202
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
203
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
204
+ _sfdp_pattern_8_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py ADDED
@@ -0,0 +1,220 @@
1
+ # mypy: ignore-errors
2
+
3
+ # noqa: F401, E501
4
+ # This is an auto-generated file. Please do not modify it by hand.
5
+ # To re-generate, run:
6
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
7
+
8
+ import torch
9
+ import torch._inductor
10
+
11
+ aten = torch.ops.aten
12
+ prims = torch.ops.prims
13
+
14
+ from torch._inductor.pattern_matcher import (
15
+ Arg,
16
+ CallFunction,
17
+ CallFunctionVarArgs,
18
+ CallMethod,
19
+ CallMethodVarArgs,
20
+ CallModule,
21
+ CallModuleVarArgs,
22
+ ExclusiveKeywordArg,
23
+ Ignored,
24
+ KeywordArg,
25
+ ListOf,
26
+ MultiOutputPattern,
27
+ PatternExpr,
28
+ RepeatedExpr,
29
+ _TargetArgsExpr,
30
+ _TargetExpr,
31
+ _TargetExprVarArgs,
32
+ )
33
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
34
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
35
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
36
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
37
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
38
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
39
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
40
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
41
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
42
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
43
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
44
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
45
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
46
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
47
+ amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
48
+ sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
49
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
50
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
51
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
52
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
53
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
54
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
55
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
56
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
57
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
58
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
59
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
60
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
61
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
62
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
63
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
64
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
65
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
66
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
67
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored())
68
+ view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored())
69
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
70
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
71
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
72
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
73
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
74
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
75
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
76
+ view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2)
77
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
78
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
79
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
80
+ div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
81
+ permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
82
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
83
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
84
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
85
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
86
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
87
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
88
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
89
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
90
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
91
+ _sfdp_pattern_9_training = MultiOutputPattern([view_default_5,
92
+ permute_default_6,
93
+ permute_default_9,
94
+ permute_default_11,
95
+ None
96
+ ])
97
+
98
+
99
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
100
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
101
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
102
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
103
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
104
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
105
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
106
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
107
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
108
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
109
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
110
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2)
111
+ amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True)
112
+ sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default)
113
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
114
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
115
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
116
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
117
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored())
118
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
119
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
120
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
121
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
122
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
123
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
124
+ _sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
125
+
126
+
127
+ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False)
128
+ gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2)
129
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
130
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
131
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
132
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
133
+ view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2)
134
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
135
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
136
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
137
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
138
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2)
139
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
140
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
141
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
142
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
143
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
144
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
145
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
146
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
147
+ mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1)
148
+ mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored())
149
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored())
150
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
151
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
152
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
153
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
154
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
155
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2)
156
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
157
+ view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
158
+ neg_default = CallFunction(aten.neg.default, div_Tensor_1)
159
+ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
160
+ permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored())
161
+ bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4)
162
+ view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
163
+ convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored())
164
+ convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored())
165
+ mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored())
166
+ mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2)
167
+ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2)
168
+ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True)
169
+ fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4)
170
+ convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored())
171
+ view_default_8 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored(), _users=2)
172
+ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored())
173
+ bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5)
174
+ view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
175
+ div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored())
176
+ permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored())
177
+ permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored())
178
+ bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8)
179
+ view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
180
+ permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored())
181
+ permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored())
182
+ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored())
183
+ bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6)
184
+ view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
185
+ permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored())
186
+ _sfdp_pattern_9_half_training = MultiOutputPattern([view_default_5,
187
+ permute_default_6,
188
+ permute_default_9,
189
+ permute_default_11,
190
+ None
191
+ ])
192
+
193
+
194
+ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored())
195
+ div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored())
196
+ expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored())
197
+ clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format)
198
+ view_default = CallFunction(aten.view.default, clone_default, Ignored())
199
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
200
+ permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored())
201
+ expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored())
202
+ clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format)
203
+ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored())
204
+ bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
205
+ view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
206
+ convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2)
207
+ amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True)
208
+ sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default)
209
+ exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
210
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
211
+ div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
212
+ convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored())
213
+ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored())
214
+ view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
215
+ permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored())
216
+ expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored())
217
+ clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format)
218
+ view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored())
219
+ bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
220
+ _sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0)
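For reference, the _sfdp_pattern_9* graphs above serialize the eager decomposition of scaled-dot-product attention in which the query is divided by the softmax scale before the QK matmul, with dropout applied to the attention weights in the training variants. The following is a minimal eager-mode sketch of the computation these graphs trace (the dropout line corresponds to the training variants only); the shapes, scale, and dropout probability are illustrative assumptions, not values taken from the pattern file:

import torch
import torch.nn.functional as F

B, H, S, D = 2, 4, 16, 8                       # assumed batch, heads, seq len, head dim
query = torch.randn(B, H, S, D)
key = torch.randn(B, H, S, D)
value = torch.randn(B, H, S, D)
scale = D ** 0.5                               # assumed softmax scale (the Ignored() divisor)
dropout_p = 0.1                                # assumed dropout probability

scores = (query / scale) @ key.transpose(-2, -1)    # div on query, then the QK^T bmm
attn = scores.softmax(dim=-1)                       # the amax/sub/exp/sum chain in the pattern
attn = F.dropout(attn, p=dropout_p, training=True)  # the rand > dropout_p mask in the training graphs
out = attn @ value                                  # final bmm with the value tensor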
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py ADDED
@@ -0,0 +1,52 @@
+ # mypy: ignore-errors
+
+ # noqa: F401, E501
+ # This is an auto-generated file. Please do not modify it by hand.
+ # To re-generate, run:
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+ import torch
+ import torch._inductor
+
+ aten = torch.ops.aten
+ prims = torch.ops.prims
+
+ from torch._inductor.pattern_matcher import (
+ Arg,
+ CallFunction,
+ CallFunctionVarArgs,
+ CallMethod,
+ CallMethodVarArgs,
+ CallModule,
+ CallModuleVarArgs,
+ ExclusiveKeywordArg,
+ Ignored,
+ KeywordArg,
+ ListOf,
+ MultiOutputPattern,
+ PatternExpr,
+ RepeatedExpr,
+ _TargetArgsExpr,
+ _TargetExpr,
+ _TargetExprVarArgs,
+ )
+ addmm_default = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha'))
+ mul_Scalar = CallFunction(aten.mul.Scalar, KeywordArg('tangents_1'), KeywordArg('beta'))
+ sum_dim_IntList = CallFunction(aten.sum.dim_IntList, mul_Scalar, Ignored(), True)
+ view_default = CallFunction(aten.view.default, sum_dim_IntList, Ignored())
+ permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored())
+ mm_default = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default)
+ mul_Scalar_1 = CallFunction(aten.mul.Scalar, mm_default, KeywordArg('alpha'))
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored())
+ mm_default_1 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1'))
+ mul_Scalar_2 = CallFunction(aten.mul.Scalar, mm_default_1, KeywordArg('alpha'))
+ addmm_pattern_training = MultiOutputPattern([addmm_default,
+ view_default,
+ mul_Scalar_1,
+ mul_Scalar_2,
+ None,
+ None
+ ])
+
+
+ addmm_pattern_inference = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha'), _users=0)
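The addmm pattern above pairs the forward out = beta * input + alpha * (mat1 @ mat2) with its backward graph: the bias gradient is beta * tangents_1 summed over the broadcast dimension, and the two matrix gradients are the mm calls scaled by alpha. A small sketch checking that correspondence against autograd; the sizes below are illustrative assumptions:

import torch

inp = torch.randn(8, requires_grad=True)        # bias broadcast over rows (assumed 1-D)
mat1 = torch.randn(4, 6, requires_grad=True)
mat2 = torch.randn(6, 8, requires_grad=True)
alpha, beta = 2.0, 0.5

out = torch.addmm(inp, mat1, mat2, beta=beta, alpha=alpha)
tangents_1 = torch.randn_like(out)
out.backward(tangents_1)

# Gradients matching the mul_Scalar/sum_dim_IntList chain and the alpha-scaled mm calls above
assert torch.allclose(inp.grad, (beta * tangents_1).sum(dim=0), atol=1e-5)
assert torch.allclose(mat1.grad, alpha * (tangents_1 @ mat2.t()), atol=1e-5)
assert torch.allclose(mat2.grad, alpha * (mat1.t() @ tangents_1), atol=1e-5)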
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py ADDED
@@ -0,0 +1,44 @@
+ # mypy: ignore-errors
+
+ # noqa: F401, E501
+ # This is an auto-generated file. Please do not modify it by hand.
+ # To re-generate, run:
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+ import torch
+ import torch._inductor
+
+ aten = torch.ops.aten
+ prims = torch.ops.prims
+
+ from torch._inductor.pattern_matcher import (
+ Arg,
+ CallFunction,
+ CallFunctionVarArgs,
+ CallMethod,
+ CallMethodVarArgs,
+ CallModule,
+ CallModuleVarArgs,
+ ExclusiveKeywordArg,
+ Ignored,
+ KeywordArg,
+ ListOf,
+ MultiOutputPattern,
+ PatternExpr,
+ RepeatedExpr,
+ _TargetArgsExpr,
+ _TargetExpr,
+ _TargetExprVarArgs,
+ )
+ bmm_default = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2'))
+ permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored())
+ bmm_default_1 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default)
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored())
+ bmm_default_2 = CallFunction(aten.bmm.default, permute_default_1, KeywordArg('tangents_1'))
+ bmm_pattern_training = MultiOutputPattern([bmm_default,
+ bmm_default_1,
+ bmm_default_2
+ ])
+
+
+ bmm_pattern_inference = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0)
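Likewise, the bmm training pattern records the forward torch.bmm(mat1, mat2) together with the two batched matmuls that produce its input gradients; the mm pattern in the next file is the unbatched analogue of the same identity. A brief autograd check, with illustrative shapes:

import torch

mat1 = torch.randn(3, 4, 5, requires_grad=True)
mat2 = torch.randn(3, 5, 6, requires_grad=True)

out = torch.bmm(mat1, mat2)
tangents_1 = torch.randn_like(out)
out.backward(tangents_1)

# The permute + bmm pairs serialized above compute exactly these gradients
assert torch.allclose(mat1.grad, torch.bmm(tangents_1, mat2.transpose(1, 2)), atol=1e-5)
assert torch.allclose(mat2.grad, torch.bmm(mat1.transpose(1, 2), tangents_1), atol=1e-5)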
.venv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py ADDED
@@ -0,0 +1,44 @@
+ # mypy: ignore-errors
+
+ # noqa: F401, E501
+ # This is an auto-generated file. Please do not modify it by hand.
+ # To re-generate, run:
+ # cd ~/pytorch && python torchgen/fuse/gen_patterns.py
+
+ import torch
+ import torch._inductor
+
+ aten = torch.ops.aten
+ prims = torch.ops.prims
+
+ from torch._inductor.pattern_matcher import (
+ Arg,
+ CallFunction,
+ CallFunctionVarArgs,
+ CallMethod,
+ CallMethodVarArgs,
+ CallModule,
+ CallModuleVarArgs,
+ ExclusiveKeywordArg,
+ Ignored,
+ KeywordArg,
+ ListOf,
+ MultiOutputPattern,
+ PatternExpr,
+ RepeatedExpr,
+ _TargetArgsExpr,
+ _TargetExpr,
+ _TargetExprVarArgs,
+ )
+ mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'))
+ permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored())
+ mm_default_1 = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default)
+ permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored())
+ mm_default_2 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1'))
+ mm_pattern_training = MultiOutputPattern([mm_default,
+ mm_default_1,
+ mm_default_2
+ ])
+
+
+ mm_pattern_inference = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0)
.venv/Lib/site-packages/torch/_inductor/fx_passes/split_cat.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/Lib/site-packages/torch/_inductor/kernel/__init__.py ADDED
@@ -0,0 +1 @@
+ from . import mm, mm_common, mm_plus_mm, unpack_mixed_mm
.venv/Lib/site-packages/torch/_inductor/kernel/bmm.py ADDED
@@ -0,0 +1,192 @@
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+
4
+ import torch
5
+
6
+ from .. import ir, lowering as L
7
+ from ..select_algorithm import (
8
+ autotune_select_algorithm,
9
+ ExternKernelChoice,
10
+ TritonTemplate,
11
+ )
12
+ from ..utils import (
13
+ ceildiv as cdiv,
14
+ use_aten_gemm_kernels,
15
+ use_cutlass_template,
16
+ use_triton_template,
17
+ )
18
+ from ..virtualized import V
19
+ from .mm import _is_static_problem
20
+ from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options
21
+
22
+
23
+ log = logging.getLogger(__name__)
24
+ aten = torch.ops.aten
25
+
26
+
27
+ def bmm_grid(b, m, n, meta):
28
+ return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)
29
+
30
+
31
+ bmm_template = TritonTemplate(
32
+ name="bmm",
33
+ grid=bmm_grid,
34
+ source=r"""
35
+ {{def_kernel("A", "B")}}
36
+ M = {{size("A", -2)}}
37
+ N = {{size("B", -1)}}
38
+ K = {{size("A", -1)}}
39
+
40
+ stride_aq = {{stride("A", 0)}}
41
+ stride_am = {{stride("A", 1)}}
42
+ stride_ak = {{stride("A", 2)}}
43
+
44
+ stride_bq = {{stride("B", 0)}}
45
+ stride_bk = {{stride("B", 1)}}
46
+ stride_bn = {{stride("B", 2)}}
47
+
48
+ # based on triton.ops.matmul
49
+ pid = tl.program_id(0)
50
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
51
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
52
+
53
+ # re-order program ID for better L2 performance
54
+ width = GROUP_M * grid_n
55
+ group_id = pid // width
56
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
57
+ pid_m = group_id * GROUP_M + (pid % group_size)
58
+ pid_n = (pid % width) // (group_size)
59
+
60
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
61
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
62
+ if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
63
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
64
+ else:
65
+ ram = rm % M
66
+ if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
67
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
68
+ else:
69
+ rbn = rn % N
70
+
71
+ rk = tl.arange(0, BLOCK_K)
72
+
73
+ idx_q = tl.program_id(1) # batch dimension for BMM
74
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
75
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)
76
+
77
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
78
+ for k in range(K, 0, -BLOCK_K):
79
+ if EVEN_K:
80
+ a = tl.load(A)
81
+ b = tl.load(B)
82
+ else:
83
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
84
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
85
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
86
+ A += BLOCK_K * stride_ak
87
+ B += BLOCK_K * stride_bk
88
+
89
+ # rematerialize rm and rn to save registers
90
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
91
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
92
+ idx_q = tl.program_id(1) # batch dimension for BMM
93
+ idx_m = rm[:, None]
94
+ idx_n = rn[None, :]
95
+ mask = (idx_m < M) & (idx_n < N)
96
+
97
+ # inductor generates a suffix
98
+ {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}}
99
+ """,
100
+ )
101
+
102
+ aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out")
103
+ aten_baddbmm = ExternKernelChoice(torch.baddbmm, "at::baddbmm_out")
104
+
105
+
106
+ @L.register_lowering(aten.bmm)
107
+ def tuned_bmm(mat1, mat2, *, layout=None):
108
+ if all(x.get_device().type == "cpu" for x in [mat1, mat2]):
109
+ # decompose to small ops when memory bound
110
+ if mat1.get_size()[1] == 1 or mat2.get_size()[2] == 1:
111
+ mat1 = L.unsqueeze(mat1, -1)
112
+ mat2 = L.unsqueeze(mat2, 1)
113
+ return L.sum_(L.mul(mat1, mat2), axis=2)
114
+
115
+ def is_valid_to_require_contiguous(t):
116
+ if not ir.is_storage_and_layout(t):
117
+ return True
118
+ _, layout = ir.as_storage_and_layout(t, freeze=False)
119
+ return isinstance(layout, ir.FlexibleLayout)
120
+
121
+ def is_preferred_layout_as_bmm_input(sizes, strides):
122
+ # contiguous on one of the last two dims
123
+ return (
124
+ strides[-1] == 1 and (sizes[-2] == 1 or strides[-2] >= sizes[-1])
125
+ ) or (strides[-2] == 1 and (sizes[-1] == 1 or strides[-1] >= sizes[-2]))
126
+
127
+ # Make the input of bmm contiguous
128
+ # if it is not contiguous on either of the last two dims,
129
+ # because bmm cpu implementation would do contiguous() if not.
130
+ # This is to avoid additional copies in bmm.
131
+ def may_require_contiguous(t, meta_t):
132
+ sizes = meta_t.meta["val"].size()
133
+ strides = meta_t.meta["val"].stride()
134
+ if not is_preferred_layout_as_bmm_input(sizes, strides):
135
+ t = ir.ExternKernel.require_contiguous(t)
136
+ return t
137
+
138
+ if is_valid_to_require_contiguous(mat1):
139
+ meta_mat1 = V.graph.current_node.args[0]
140
+ mat1 = may_require_contiguous(mat1, meta_mat1)
141
+ if is_valid_to_require_contiguous(mat2):
142
+ meta_mat2 = V.graph.current_node.args[1]
143
+ mat2 = may_require_contiguous(mat2, meta_mat2)
144
+
145
+ m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
146
+
147
+ # options to tune from
148
+ choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
149
+ if use_triton_template(layout):
150
+ for config in mm_configs(m, n, k):
151
+ bmm_template.maybe_append_choice(
152
+ choices,
153
+ input_nodes=(mat1, mat2),
154
+ layout=layout,
155
+ **mm_options(config, m, n, k, layout),
156
+ )
157
+ static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
158
+ if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
159
+ from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate
160
+
161
+ CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
162
+
163
+ if len(choices) == 0:
164
+ log.warning("No choices for GEMM, using ATen backend as fallback")
165
+ choices.append(aten_bmm.bind((mat1, mat2), layout))
166
+
167
+ return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)
168
+
169
+
170
+ # Don't register this since it is slower than decomposing it
171
+ # @L.register_lowering(aten.baddbmm)
172
+ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
173
+ m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
174
+
175
+ # options to tune from
176
+ choices = (
177
+ [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
178
+ if use_aten_gemm_kernels()
179
+ else []
180
+ )
181
+ if use_triton_template(layout):
182
+ for config in mm_configs(m, n, k):
183
+ bmm_template.maybe_append_choice(
184
+ choices,
185
+ input_nodes=(inp, mat1, mat2),
186
+ layout=layout,
187
+ **mm_options(config, m, n, k, layout),
188
+ prefix_args=1,
189
+ epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
190
+ )
191
+
192
+ return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout)
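One detail worth noting in tuned_bmm above: on CPU, batched matmuls whose M or N dimension equals 1 are decomposed into an unsqueeze, a broadcasted multiply, and a sum over K instead of being autotuned, since they are memory bound. A small sketch of the algebraic equivalence that decomposition relies on; the shapes are illustrative:

import torch

a = torch.randn(3, 5, 7)   # (B, M, K)
b = torch.randn(3, 7, 1)   # (B, K, N) with N == 1, the memory-bound case

# unsqueeze(-1) / unsqueeze(1) + mul + sum over the K axis, as in the lowering
decomposed = (a.unsqueeze(-1) * b.unsqueeze(1)).sum(dim=2)

assert torch.allclose(torch.bmm(a, b), decomposed, atol=1e-5)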
.venv/Lib/site-packages/torch/_inductor/kernel/conv.py ADDED
@@ -0,0 +1,679 @@
1
+ # mypy: allow-untyped-decorators
2
+ # mypy: allow-untyped-defs
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ import logging
7
+ from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDict
8
+
9
+ import torch
10
+
11
+ from .. import config, ir
12
+ from ..lowering import (
13
+ add_layout_constraint,
14
+ constrain_to_fx_strides,
15
+ lowerings as L,
16
+ register_lowering,
17
+ )
18
+ from ..select_algorithm import (
19
+ autotune_select_algorithm,
20
+ ExternKernelChoice,
21
+ TritonTemplate,
22
+ )
23
+ from ..utils import (
24
+ ceildiv,
25
+ is_ones,
26
+ is_zeros,
27
+ pad_listlike,
28
+ sympy_product,
29
+ use_triton_template,
30
+ )
31
+ from ..virtualized import V
32
+ from .mm_common import filtered_configs
33
+
34
+
35
+ if TYPE_CHECKING:
36
+ from ..ir import TensorBox
37
+
38
+ log = logging.getLogger(__name__)
39
+
40
+
41
+ aten = torch.ops.aten
42
+
43
+
44
+ def conv2d_grid(n, c, h, w, meta):
45
+ return (
46
+ ceildiv(n * h * w, meta["BLOCK_M"]),
47
+ ceildiv(c, meta["BLOCK_N"]),
48
+ meta["GROUPS"],
49
+ )
50
+
51
+
52
+ def conv3d_grid(n, c, d, h, w, meta):
53
+ return (
54
+ ceildiv(n * d * h * w, meta["BLOCK_M"]),
55
+ ceildiv(c, meta["BLOCK_N"]),
56
+ meta["GROUPS"],
57
+ )
58
+
59
+
60
+ # List of dictionaries to store the kernel configs. Configs that evaluate to true
61
+ # will be utilised on the target platform
62
+ kernel_configs = [
63
+ # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
64
+ {"config": (64, 256, 16, 2, 4), "cond": True},
65
+ {"config": (256, 64, 16, 2, 4), "cond": True},
66
+ {"config": (1024, 16, 16, 1, 8), "cond": True},
67
+ {"config": (128, 128, 32, 2, 8), "cond": True},
68
+ {"config": (64, 64, 32, 2, 4), "cond": True},
69
+ {"config": (64, 256, 32, 2, 8), "cond": True},
70
+ {"config": (256, 64, 32, 2, 8), "cond": True},
71
+ ]
72
+
73
+ # Create filtered list of configs based on conv
74
+ platform_configs = tuple(
75
+ cast(Tuple[int, int, int, int, int], config["config"])
76
+ for config in kernel_configs
77
+ if config["cond"]
78
+ )
79
+
80
+ # On ROCm convert num_stages to 1 as pipelining provides no benefit
81
+ if torch.version.hip:
82
+ platform_configs = tuple(
83
+ (config[0], config[1], config[2], 1, config[4]) for config in platform_configs
84
+ )
85
+
86
+ conv_configs = functools.partial(
87
+ filtered_configs,
88
+ configs=platform_configs,
89
+ )
90
+
91
+ LOOP_BODY_2D = """
92
+ idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
93
+ idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
94
+ idx_x_c = tl.arange(0, BLOCK_K) + k
95
+
96
+ x_ptrs = x_base + (
97
+ (idx_x_h * stride_xh)[:, None]
98
+ + (idx_x_w * stride_xw)[:, None]
99
+ + (idx_x_c * stride_xc)[None, :]
100
+ )
101
+ mask_x = (
102
+ (idx_n < BATCH)[:, None]
103
+ & (idx_x_h >= 0)[:, None]
104
+ & (idx_x_h < IN_H)[:, None]
105
+ & (idx_x_w >= 0)[:, None]
106
+ & (idx_x_w < IN_W)[:, None]
107
+ & (idx_x_c < GROUP_IN_C)[None, :]
108
+ )
109
+ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
110
+
111
+ w_ptrs = w_base + (
112
+ (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww)
113
+ )
114
+ mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
115
+ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
116
+ acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
117
+ """
118
+
119
+ """
120
+ This is a relatively simple conv implementation that can likely be
121
+ improved. Many alternate conv versions can be found here:
122
+ https://github.com/pytorch/torchdynamo/pull/971
123
+ """
124
+ conv2d_template = TritonTemplate(
125
+ name="convolution2d",
126
+ grid=conv2d_grid,
127
+ source=r"""
128
+ {{def_kernel("X", "W")}}
129
+ # Tensor dimensions
130
+ BATCH = {{size("X", 0)}}
131
+ IN_C = {{size("X", 1)}}
132
+ IN_H = {{size("X", 2)}}
133
+ IN_W = {{size("X", 3)}}
134
+ OUT_C = {{size(None, 1)}}
135
+ OUT_H = {{size(None, 2)}}
136
+ OUT_W = {{size(None, 3)}}
137
+
138
+ # Strides:
139
+ stride_xn = {{stride("X", 0)}}
140
+ stride_xc = {{stride("X", 1)}}
141
+ stride_xh = {{stride("X", 2)}}
142
+ stride_xw = {{stride("X", 3)}}
143
+ stride_wc_out = {{stride("W", 0)}}
144
+ stride_wc_in = {{stride("W", 1)}}
145
+ stride_wh = {{stride("W", 2)}}
146
+ stride_ww = {{stride("W", 3)}}
147
+
148
+ nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
149
+ idx_y_w = nhw % OUT_W
150
+ nh = nhw // OUT_W
151
+ idx_y_h = nh % OUT_H
152
+ idx_n = nh // OUT_H
153
+ idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
154
+
155
+ {% if GROUPS == 1 %}
156
+ group = 0
157
+ GROUP_IN_C = IN_C
158
+ GROUP_OUT_C = OUT_C
159
+ {% else %}
160
+ group = tl.program_id(2)
161
+ GROUP_IN_C = IN_C // GROUPS
162
+ GROUP_OUT_C = OUT_C // GROUPS
163
+ {% endif %}
164
+
165
+ x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
166
+ w_base = (
167
+ W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
168
+ )
169
+
170
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
171
+
172
+ {% if UNROLL %}
173
+ {% for i in range(KERNEL_H) %}
174
+ {% for j in range(KERNEL_W) %}
175
+ i = {{i}}
176
+ j = {{j}}
177
+ for k in range(0, GROUP_IN_C, BLOCK_K):
178
+ """
179
+ + LOOP_BODY_2D
180
+ + """
181
+ {% endfor %}
182
+ {% endfor %}
183
+ {% else %}
184
+ # Could be simplified, but slightly slower:
185
+ # for i in range(KERNEL_H):
186
+ # for j in range(KERNEL_W):
187
+ # for k in range(0, GROUP_IN_C, BLOCK_K):
188
+ BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
189
+ for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
190
+ k = (ijk % BLOCK_K_COUNT) * BLOCK_K
191
+ ij = ijk // BLOCK_K_COUNT
192
+ i = ij // KERNEL_W
193
+ j = ij % KERNEL_W
194
+ """
195
+ + LOOP_BODY_2D
196
+ + """
197
+ {% endif %}
198
+
199
+ mask = (
200
+ (idx_n < BATCH)[:, None]
201
+ & (idx_y_h < OUT_H)[:, None]
202
+ & (idx_y_w < OUT_W)[:, None]
203
+ & (idx_y_c < GROUP_OUT_C)[None, :]
204
+ )
205
+ idx_n = idx_n[:, None]
206
+ idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
207
+ idx_h = idx_y_h[:, None]
208
+ idx_w = idx_y_w[:, None]
209
+
210
+ # inductor generates a suffix
211
+ {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}}
212
+ """,
213
+ )
214
+
215
+ LOOP_BODY_3D = """
216
+ idx_x_d = d - PADDING_D + idx_y_d * STRIDE_D
217
+ idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
218
+ idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
219
+ idx_x_c = tl.arange(0, BLOCK_K) + k
220
+
221
+ x_ptrs = x_base + (
222
+ (idx_x_d * stride_xd)[:, None]
223
+ + (idx_x_h * stride_xh)[:, None]
224
+ + (idx_x_w * stride_xw)[:, None]
225
+ + (idx_x_c * stride_xc)[None, :]
226
+ )
227
+ mask_x = (
228
+ (idx_n < BATCH)[:, None]
229
+ & (idx_x_d >= 0)[:, None]
230
+ & (idx_x_d < IN_D)[:, None]
231
+ & (idx_x_h >= 0)[:, None]
232
+ & (idx_x_h < IN_H)[:, None]
233
+ & (idx_x_w >= 0)[:, None]
234
+ & (idx_x_w < IN_W)[:, None]
235
+ & (idx_x_c < GROUP_IN_C)[None, :]
236
+ )
237
+ matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
238
+
239
+ w_ptrs = w_base + (
240
+ (idx_x_c * stride_wc_in)[:, None] +
241
+ (d * stride_wd) + (i * stride_wh) + (j * stride_ww)
242
+ )
243
+ mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
244
+ matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
245
+ acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
246
+ """
247
+
248
+ conv3d_template = TritonTemplate(
249
+ name="convolution3d",
250
+ grid=conv3d_grid,
251
+ source=r"""
252
+ {{def_kernel("X", "W")}}
253
+ # Tensor dimensions
254
+ BATCH = {{size("X", 0)}}
255
+ IN_C = {{size("X", 1)}}
256
+ IN_D = {{size("X", 2)}}
257
+ IN_H = {{size("X", 3)}}
258
+ IN_W = {{size("X", 4)}}
259
+ OUT_C = {{size(None, 1)}}
260
+ OUT_D = {{size(None, 2)}}
261
+ OUT_H = {{size(None, 3)}}
262
+ OUT_W = {{size(None, 4)}}
263
+
264
+ # Strides:
265
+ stride_xn = {{stride("X", 0)}}
266
+ stride_xc = {{stride("X", 1)}}
267
+ stride_xd = {{stride("X", 2)}}
268
+ stride_xh = {{stride("X", 3)}}
269
+ stride_xw = {{stride("X", 4)}}
270
+ stride_wc_out = {{stride("W", 0)}}
271
+ stride_wc_in = {{stride("W", 1)}}
272
+ stride_wd = {{stride("W", 2)}}
273
+ stride_wh = {{stride("W", 3)}}
274
+ stride_ww = {{stride("W", 4)}}
275
+
276
+ ndhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
277
+ idx_y_w = ndhw % OUT_W
278
+ ndh = ndhw // OUT_W
279
+ idx_y_h = ndh % OUT_H
280
+ nd = ndh // OUT_H
281
+ idx_y_d = nd % OUT_D
282
+ idx_n = nd // OUT_D
283
+ idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
284
+
285
+ {% if GROUPS == 1 %}
286
+ group = 0
287
+ GROUP_IN_C = IN_C
288
+ GROUP_OUT_C = OUT_C
289
+ {% else %}
290
+ group = tl.program_id(2)
291
+ GROUP_IN_C = IN_C // GROUPS
292
+ GROUP_OUT_C = OUT_C // GROUPS
293
+ {% endif %}
294
+
295
+ x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
296
+ w_base = (
297
+ W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
298
+ )
299
+
300
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
301
+
302
+ {% if UNROLL %}
303
+ {% for d in range(KERNEL_D) %}
304
+ {% for i in range(KERNEL_H) %}
305
+ {% for j in range(KERNEL_W) %}
306
+ d = {{d}}
307
+ i = {{i}}
308
+ j = {{j}}
309
+ for k in range(0, GROUP_IN_C, BLOCK_K):
310
+ """
311
+ + LOOP_BODY_3D
312
+ + """
313
+ {% endfor %}
314
+ {% endfor %}
315
+ {% endfor %}
316
+ {% else %}
317
+ # Could be simplified, but slightly slower:
318
+ # for d in range(KERNEL_D):
319
+ # for i in range(KERNEL_H):
320
+ # for j in range(KERNEL_W):
321
+ # for k in range(0, GROUP_IN_C, BLOCK_K):
322
+ BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
323
+ for dijk in range(KERNEL_D * KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
324
+ k = (dijk % BLOCK_K_COUNT) * BLOCK_K
325
+ dij = dijk // BLOCK_K_COUNT
326
+ j = dij % KERNEL_W
327
+ di = dij // KERNEL_W
328
+ i = di % KERNEL_H
329
+ d = di // KERNEL_H
330
+ """
331
+ + LOOP_BODY_3D
332
+ + """
333
+ {% endif %}
334
+
335
+ mask = (
336
+ (idx_n < BATCH)[:, None]
337
+ & (idx_y_d < OUT_D)[:, None]
338
+ & (idx_y_h < OUT_H)[:, None]
339
+ & (idx_y_w < OUT_W)[:, None]
340
+ & (idx_y_c < GROUP_OUT_C)[None, :]
341
+ )
342
+ idx_n = idx_n[:, None]
343
+ idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
344
+ idx_d = idx_y_d[:, None]
345
+ idx_h = idx_y_h[:, None]
346
+ idx_w = idx_y_w[:, None]
347
+
348
+ # inductor generates a suffix
349
+ {{store_output(("idx_n", "idx_c", "idx_d", "idx_h", "idx_w"), "acc", "mask")}}
350
+ """,
351
+ )
352
+
353
+ aten_convolution = ExternKernelChoice(
354
+ torch.convolution,
355
+ "at::convolution",
356
+ has_out_variant=False,
357
+ op_overload=aten.convolution.default,
358
+ )
359
+
360
+
361
+ def conv1x1_via_mm(x, w, *, out):
362
+ w = torch.squeeze(torch.squeeze(w, -1), -1)
363
+ return torch.matmul(
364
+ x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1)
365
+ )
366
+
367
+
368
+ aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None)
369
+
370
+
371
+ class ConvLayoutParams(TypedDict):
372
+ stride: tuple[int, ...]
373
+ padding: tuple[int, ...]
374
+ dilation: tuple[int, ...]
375
+ transposed: bool
376
+ output_padding: tuple[int, ...]
377
+ groups: int
378
+
379
+
380
+ def conv_layout(
381
+ x: TensorBox,
382
+ weight: TensorBox,
383
+ bias: Optional[TensorBox],
384
+ stride: Sequence[int],
385
+ padding: tuple[int, ...],
386
+ dilation: tuple[int, ...],
387
+ transposed: bool,
388
+ output_padding: tuple[int, ...],
389
+ groups: int,
390
+ ) -> ir.Layout:
391
+ """Determine output layout for a convolution"""
392
+ with V.graph.fake_mode:
393
+ output = torch.ops.aten.convolution(
394
+ ir.ir_node_to_tensor(x, guard_shape=True),
395
+ ir.ir_node_to_tensor(weight, guard_shape=True),
396
+ ir.ir_node_to_tensor(bias, guard_shape=True),
397
+ V.graph.sizevars.size_hints(stride), # type: ignore[arg-type]
398
+ V.graph.sizevars.size_hints(padding), # type: ignore[arg-type]
399
+ V.graph.sizevars.size_hints(dilation), # type: ignore[arg-type]
400
+ transposed,
401
+ V.graph.sizevars.size_hints(output_padding), # type: ignore[arg-type]
402
+ groups,
403
+ )
404
+ sizes = ir.convert_shape_to_inductor(output.size())
405
+ stride = ir.convert_shape_to_inductor(output.stride()) # type: ignore[assignment]
406
+
407
+ return ir.FixedLayout(
408
+ x.get_device(),
409
+ x.get_dtype(),
410
+ sizes,
411
+ stride,
412
+ )
413
+
414
+
415
+ def channels_last_order(rank):
416
+ order = list(reversed(range(rank)))
417
+ order.insert(1, order.pop(-1))
418
+ return order
419
+
420
+
421
+ def convert_1x1_conv_to_mm(x, weight, bias):
422
+ # special case for 1x1 convolution, which is actually just a matmul
423
+ rank = len(weight.get_size())
424
+ for _ in range(rank - 2):
425
+ weight = L[aten.squeeze](weight, dim=-1)
426
+ weight = L[aten.permute](weight, [1, 0])
427
+
428
+ x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank))
429
+ x_permute = list(range(rank))
430
+ x_permute.append(x_permute.pop(1))
431
+ x = L[aten.permute](x, x_permute)
432
+ *sizes, in_chan = x.get_size()
433
+ x = L[aten.reshape](x, [sympy_product(sizes), in_chan])
434
+ if bias is None:
435
+ result = L[aten.mm](x, weight)
436
+ else:
437
+ result = L[aten.addmm](bias, x, weight)
438
+ result = L[aten.reshape](result, [*sizes, -1])
439
+ result_permute = list(range(rank))
440
+ result_permute.insert(1, result_permute.pop(-1))
441
+ return L[aten.permute](result, result_permute)
442
+
443
+
444
+ @register_lowering(aten.convolution)
445
+ def convolution(
446
+ x: TensorBox,
447
+ weight: TensorBox,
448
+ bias: TensorBox,
449
+ stride: List[int],
450
+ padding: List[int],
451
+ dilation: List[int],
452
+ transposed: bool,
453
+ output_padding: List[int],
454
+ groups: int,
455
+ ):
456
+ stride = tuple(stride)
457
+ padding = tuple(padding)
458
+ dilation = tuple(dilation)
459
+ output_padding = tuple(output_padding)
460
+ if not isinstance(groups, int):
461
+ groups = V.graph.sizevars.evaluate_static_shape(groups)
462
+ assert isinstance(groups, int)
463
+
464
+ # Need use hint for triton template since the template does not
465
+ # work with a dynamic shape.
466
+ #
467
+ # No need to evaluate_static_shape for dilation and output_padding
468
+ # since the template is only used when dilation is 1 and output_padding
469
+ # is 0.
470
+ stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride))
471
+ padding = tuple(V.graph.sizevars.evaluate_static_shapes(padding))
472
+
473
+ kwargs: ConvLayoutParams = {
474
+ "stride": stride,
475
+ "padding": padding,
476
+ "dilation": dilation,
477
+ "transposed": transposed,
478
+ "output_padding": output_padding,
479
+ "groups": groups,
480
+ }
481
+
482
+ if len(x.get_size()) == len(weight.get_size()) - 1:
483
+ # add batch dimension to simplify rest of function
484
+ return L[aten.squeeze](
485
+ convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs),
486
+ dim=0,
487
+ )
488
+
489
+ out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes(
490
+ weight.get_size()
491
+ )
492
+ ndim = len(kernel_shape)
493
+ stride = pad_listlike(stride, ndim)
494
+ padding = pad_listlike(padding, ndim)
495
+ dilation = pad_listlike(dilation, ndim)
496
+ output_padding = pad_listlike(output_padding, ndim)
497
+
498
+ def channels_last_conv():
499
+ if V.graph.layout_opt and ndim == 2:
500
+ return True
501
+
502
+ layout = conv_layout(x, weight, None, **kwargs)
503
+ req_stride_order = ir.get_stride_order(
504
+ V.graph.sizevars.size_hints(layout.stride)
505
+ )
506
+ return req_stride_order == ir.NHWC_STRIDE_ORDER
507
+
508
+ autotuning_gemm = config.max_autotune or config.max_autotune_gemm
509
+
510
+ if (
511
+ (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv()))
512
+ and is_ones(kernel_shape)
513
+ and is_ones(stride)
514
+ and is_zeros(padding)
515
+ and is_ones(dilation)
516
+ and not transposed
517
+ and is_zeros(output_padding)
518
+ and groups == 1
519
+ and V.graph.sizevars.statically_known_gt(sympy_product(x.get_size()), 0)
520
+ ):
521
+ return convert_1x1_conv_to_mm(x, weight, bias)
522
+
523
+ if bias is not None and ir.get_device_type(x) != "cpu":
524
+ # peel off the bias, cudnn is slower with it
525
+ result = convolution(x, weight, None, **kwargs)
526
+ return L[aten.add](
527
+ result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1])
528
+ )
529
+
530
+ x.realize()
531
+ weight.realize()
532
+
533
+ # ndim can be 1 for convolution in models such as demucs
534
+ # TODO: check if it's beneficial to convert Conv1d to Conv2d and then
535
+ # apply channels last.
536
+ if V.graph.layout_opt and ndim == 2:
537
+ V.graph.num_channels_last_conv += 1
538
+ x = ir.ExternKernel.require_channels_last(x)
539
+ # TODO maybe we can convert weights to channels last just once before
540
+ # running the model.
541
+ weight = ir.ExternKernel.require_channels_last(weight)
542
+ layout = conv_layout(x, weight, None, **kwargs)
543
+ else:
544
+ layout = conv_layout(x, weight, None, **kwargs)
545
+ req_stride_order = ir.get_stride_order(
546
+ V.graph.sizevars.size_hints(layout.stride)
547
+ )
548
+ x = ir.ExternKernel.require_stride_order(x, req_stride_order)
549
+ weight = ir.ExternKernel.require_stride_order(weight, req_stride_order)
550
+
551
+ ordered_kwargs_for_cpp_kernel = [
552
+ "stride",
553
+ "padding",
554
+ "dilation",
555
+ "transposed",
556
+ "output_padding",
557
+ "groups",
558
+ ]
559
+ if bias is None:
560
+ args = [x, weight]
561
+ kwargs["bias"] = None # type: ignore[typeddict-unknown-key]
562
+ ordered_kwargs_for_cpp_kernel.insert(0, "bias")
563
+ else:
564
+ args = [x, weight, bias]
565
+ bias.realize()
566
+ bias.freeze_layout()
567
+ V.graph.sizevars.evaluate_static_shapes(bias.get_size())
568
+
569
+ choices = []
570
+ if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
571
+ choices = [
572
+ aten_convolution.bind(
573
+ args,
574
+ layout,
575
+ ordered_kwargs_for_cpp_kernel,
576
+ **kwargs,
577
+ )
578
+ ]
579
+
580
+ if (
581
+ torch._inductor.utils._use_conv_autotune_backend("TRITON")
582
+ and use_triton_template(layout)
583
+ # templates only support these:
584
+ and is_ones(dilation)
585
+ and not transposed
586
+ and is_zeros(output_padding)
587
+ # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
588
+ and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1]) # type: ignore[arg-type]
589
+ ):
590
+ if (
591
+ is_ones(kernel_shape)
592
+ and is_ones(stride)
593
+ and is_zeros(padding)
594
+ and groups == 1
595
+ ):
596
+ choices.append(aten_conv1x1_via_mm.bind(args, layout))
597
+
598
+ for cfg in conv_configs(
599
+ sympy_product([x.get_size()[0], *x.get_size()[2:]]),
600
+ out_chan,
601
+ in_chan,
602
+ ):
603
+ if ndim == 2:
604
+ conv2d_template.maybe_append_choice(
605
+ choices,
606
+ input_nodes=(x, weight),
607
+ layout=layout,
608
+ KERNEL_H=kernel_shape[0],
609
+ KERNEL_W=kernel_shape[1],
610
+ STRIDE_H=stride[0],
611
+ STRIDE_W=stride[1],
612
+ PADDING_H=padding[0],
613
+ PADDING_W=padding[1],
614
+ GROUPS=groups,
615
+ # TODO(jansel): try unroll for bigger kernels once fixed:
616
+ # https://github.com/openai/triton/issues/1254
617
+ UNROLL=is_ones(kernel_shape),
618
+ ALLOW_TF32=torch.backends.cudnn.allow_tf32,
619
+ num_stages=cfg.num_stages,
620
+ num_warps=cfg.num_warps,
621
+ **cfg.kwargs,
622
+ )
623
+ elif ndim == 3:
624
+ conv3d_template.maybe_append_choice(
625
+ choices,
626
+ input_nodes=(x, weight),
627
+ layout=layout,
628
+ KERNEL_D=kernel_shape[0],
629
+ KERNEL_H=kernel_shape[1],
630
+ KERNEL_W=kernel_shape[2],
631
+ STRIDE_D=stride[0],
632
+ STRIDE_H=stride[1],
633
+ STRIDE_W=stride[2],
634
+ PADDING_D=padding[0],
635
+ PADDING_H=padding[1],
636
+ PADDING_W=padding[2],
637
+ GROUPS=groups,
638
+ # TODO(jansel): try unroll for bigger kernels once fixed:
639
+ # https://github.com/openai/triton/issues/1254
640
+ UNROLL=is_ones(kernel_shape),
641
+ ALLOW_TF32=torch.backends.cudnn.allow_tf32,
642
+ num_stages=cfg.num_stages,
643
+ num_warps=cfg.num_warps,
644
+ **cfg.kwargs,
645
+ )
646
+
647
+ return autotune_select_algorithm("convolution", choices, args, layout)
648
+
649
+
650
+ @register_lowering(aten._convolution)
651
+ def _convolution(
652
+ x,
653
+ weight,
654
+ bias,
655
+ stride,
656
+ padding,
657
+ dilation,
658
+ transposed,
659
+ output_padding,
660
+ groups,
661
+ benchmark,
662
+ deterministic,
663
+ cudnn_enabled,
664
+ allow_tf32,
665
+ ):
666
+ return convolution(
667
+ x, weight, bias, stride, padding, dilation, transposed, output_padding, groups
668
+ )
669
+
670
+
671
+ def constrain_conv_to_fx_strides(fx_node, *args, **kwargs):
672
+ assert fx_node.target == torch.ops.aten.convolution.default
673
+ if V.graph.layout_opt:
674
+ return args, kwargs
675
+ else:
676
+ return constrain_to_fx_strides(fx_node, *args, **kwargs)
677
+
678
+
679
+ add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides)
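The 1x1 fast path above (convert_1x1_conv_to_mm and the aten_conv1x1_via_mm choice) relies on the fact that a pointwise convolution with stride 1, padding 0, dilation 1 and groups == 1 is just a matmul over the channel dimension. A minimal eager-mode sketch of that equivalence, with made-up toy shapes:

import torch

# 1x1 conv vs. matmul over channels: both produce an (N, C_out, H, W) result.
x = torch.randn(2, 8, 5, 5)       # (N, C_in, H, W)
w = torch.randn(16, 8, 1, 1)      # (C_out, C_in, 1, 1)

conv_out = torch.nn.functional.conv2d(x, w)
mm_out = (x.permute(0, 2, 3, 1) @ w.flatten(1).t()).permute(0, 3, 1, 2)
assert torch.allclose(conv_out, mm_out, atol=1e-5)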
.venv/Lib/site-packages/torch/_inductor/kernel/flex_attention.py ADDED
@@ -0,0 +1,1843 @@
1
+ # mypy: allow-untyped-defs
2
+ """ Triton Implementation of the flex_attention Kernel"""
3
+
4
+ import logging
5
+ import math
6
+ from typing import Any, List, Optional, Sequence, Tuple
7
+
8
+ import sympy
9
+
10
+ import torch
11
+ from torch._inductor.virtualized import V
12
+ from torch.utils._pytree import tree_map
13
+
14
+ from .. import config
15
+ from ..ir import (
16
+ ComputedBuffer,
17
+ ExternKernel,
18
+ FixedLayout,
19
+ FlexibleLayout,
20
+ get_stride_order,
21
+ InputBuffer,
22
+ IRNode,
23
+ StorageBox,
24
+ stride_order2fill_order,
25
+ Subgraph,
26
+ TensorBox,
27
+ )
28
+ from ..lowering import empty, empty_strided, lowerings, register_lowering
29
+ from ..select_algorithm import autotune_select_algorithm, realize_inputs, TritonTemplate
30
+
31
+
32
+ log = logging.getLogger(__name__)
33
+ aten = torch.ops.aten
34
+ Expr = sympy.Expr
35
+
36
+
37
+ def construct_strides(
38
+ sizes: Sequence[int],
39
+ fill_order: Sequence[int],
40
+ ) -> Sequence[int]:
41
+ """From a list of sizes and a fill order, construct the strides of the permuted tensor."""
42
+ # Initialize strides
43
+ assert len(sizes) == len(
44
+ fill_order
45
+ ), "Length of sizes must match the length of the fill order"
46
+ strides = [0] * len(sizes)
47
+
48
+ # Start with stride 1 for the innermost dimension
49
+ current_stride = 1
50
+
51
+ # Iterate through the fill order populating strides
52
+ for dim in fill_order:
53
+ strides[dim] = current_stride
54
+ current_stride *= sizes[dim]
55
+
56
+ return strides
57
+
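A quick illustration of construct_strides (this calls the helper defined above; the sizes are made up): the fill order lists dimensions from innermost to outermost, so each dimension's stride is the product of the sizes filled before it.

sizes = [2, 8, 4, 4]
assert construct_strides(sizes, [3, 2, 1, 0]) == [128, 16, 4, 1]   # contiguous fill
assert construct_strides(sizes, [1, 3, 2, 0]) == [128, 1, 32, 8]   # dim 1 filled innermost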
58
+
59
+ def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta):
60
+ """How is this kernel parallelized?
61
+ We create a grid of (ceil_div(n_queries, query_block_size), batch_size * num_heads, 1)
62
+ Each block is responsible for iterating over blocks of keys and values calculating
63
+ the final attention output.
64
+ """
65
+ import triton
66
+
67
+ return (triton.cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1)
68
+
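For concreteness, with hypothetical sizes (batch_size=4, q_heads=8, 1024 queries, BLOCK_M=128) the grid above works out to (8, 32, 1): one program per query block along axis 0 and one per (batch, head) pair along axis 1. A triton-free sketch of the same arithmetic:

import math

batch_size, q_heads, num_queries, BLOCK_M = 4, 8, 1024, 128
grid = (math.ceil(num_queries / BLOCK_M), batch_size * q_heads, 1)
assert grid == (8, 32, 1)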
69
+
70
+ def create_placeholder(
71
+ name: str, dtype: torch.dtype, device: torch.device
72
+ ) -> TensorBox:
73
+ """Creates a placeholder input buffers for producing subgraph_output."""
74
+ input_buffer = InputBuffer(name, FixedLayout(device, dtype, [], []))
75
+ return TensorBox.create(input_buffer)
76
+
77
+
78
+ def maybe_realize(args: List[Optional[IRNode]]):
79
+ """Accepts a list of optional IRNodes and returns a list of realized IRNodes"""
80
+ return tree_map(lambda x: realize_inputs(x) if x is not None else None, args)
81
+
82
+
83
+ def get_float32_precision():
84
+ if torch.get_float32_matmul_precision() == "highest" or torch.version.hip:
85
+ return "'ieee'"
86
+ else:
87
+ return "'tf32'"
88
+
89
+
90
+ def build_subgraph_buffer(
91
+ args: List[TensorBox],
92
+ subgraph: Subgraph,
93
+ ):
94
+ """This function's goal is to take in the required args and produce the subgraph buffer
95
+ The subgraph buffer is a ComputedBuffer that will be inlined into the triton template
96
+
97
+ Args:
98
+ args: The args that are passed into the subgraph. Contains both fixed and lifted inputs.
99
+ subgraph: The Subgraph ir for which to produce the output node
100
+ """
101
+ cnt = 0
102
+ env = {}
103
+ for node in subgraph.graph_module.graph.nodes:
104
+ # There are two classes of placeholder inputs that we need
105
+ # to handle differently. For the first n_scalar_inps inputs
106
+ # we expect that these placeholders were generated by the make_fx call
107
+ # in the flex Attention HOP. So we need to create a new placeholder
108
+ # TensorBox for each of these inputs. For the rest of the inputs we
109
+ # expect that these are lifted inputs that fill up the '*other_buffers'
110
+ # tuple and already have corresponding TensorBoxes passed in as args.
111
+ if node.op == "placeholder":
112
+ env[node] = args[cnt]
113
+ cnt += 1
114
+ elif node.op == "call_function":
115
+ # For call_function we use the default lowerings and pass in the
116
+ # already created TensorBoxes as args
117
+
118
+ args, kwargs = tree_map(
119
+ lambda x: env[x] if x in env else x, (node.args, node.kwargs)
120
+ )
121
+ env[node] = lowerings[node.target](*args, **kwargs)
122
+ elif node.op == "output":
123
+
124
+ def convert_output_node_to_buffer(output):
125
+ if output is None:
126
+ return None
127
+ output_node = output
128
+ output_buffer = env[output_node]
129
+ assert isinstance(output_buffer, TensorBox), (
130
+ "The output node for flex attention's subgraph must be a TensorBox, but got: ",
131
+ type(output_buffer),
132
+ )
133
+ assert isinstance(output_buffer.data, StorageBox), (
134
+ "The output node for the flex attention subgraph must be a StorageBox, but got: ",
135
+ type(output_buffer),
136
+ )
137
+ subgraph_buffer = ComputedBuffer(
138
+ name=None,
139
+ layout=FlexibleLayout(
140
+ device=output_buffer.data.get_device(),
141
+ dtype=output_buffer.data.get_dtype(),
142
+ size=output_buffer.data.get_size(),
143
+ ),
144
+ data=output_buffer.data.data, # type: ignore[arg-type]
145
+ )
146
+ return subgraph_buffer
147
+
148
+ # node.args[0] is either a single element or a list of elements
149
+ # representing all outputs of the function.
150
+ return tree_map(convert_output_node_to_buffer, node.args[0])
151
+
152
+ raise ValueError("FlexAttention was passed a subgraph with no output node!")
153
+
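For orientation (this is an illustration, not part of the file): the subgraph handled here is a traced score_mod (placeholders score, b, h, m, n) or mask_mod (placeholders b, h, m, n), followed by any lifted buffers, as set up in the flex_attention lowering below. A relative-position bias, for example, would have been traced from something like:

def rel_bias_score_mod(score, b, h, m, n):
    # add a bias that depends on the distance between query and key positions
    return score + (m - n)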
154
+
155
+ # Inner Triton functions shared by flex_attention & split-k decoding kernels.
156
+ compute_next_offset_func = r"""
157
+ @triton.jit
158
+ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
159
+ cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
160
+ cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
161
+ next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
162
+ needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
163
+ jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
164
+
165
+ offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
166
+ return offset
167
+ """
168
+
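A pure-Python paraphrase of the helper above, illustrative only: within one sparse block the pointer advances by BLOCK, and once SPARSE_BLOCK_MULTIPLE inner steps have been consumed it jumps to the start of the next sparse block listed in col_indices.

def offset_for_next_block_py(loop_iter, col_indices, total_blocks,
                             SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK):
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = col_indices[cur_block_idx]
    # mirrors the masked tl.load: the value is only used when a jump is needed
    next_block = col_indices[cur_block_idx + 1] if cur_block_idx + 1 < total_blocks else cur_block
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    return jump_to_block if needs_jump else BLOCK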
169
+ compute_flex_attention = r"""
170
+ {{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
171
+ # Sub notation for this kernel:
172
+ #
173
+ # Q: Query, K: Key, V: Value
174
+ # M: Number of queries, N: Number of keys/values, D: Model dimension
175
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
176
+ # V_HEAD_DIM: The dimension of the value embeddings
177
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
178
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
179
+ #
180
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
181
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
182
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
183
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
184
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
185
+ #
186
+ # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
187
+ #
188
+ # (Modifiable) Performance tuning options
189
+ # BLOCK_M: The thread block size across the seqlen dim of Q.
190
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
191
+
192
+ # The below are kernel options that can be applied for certain score_mods,
193
+ # or involve a numerics vs. perf tradeoff
194
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
195
+ # about 20% more numerical error, but slightly faster.
196
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
197
+ # is not masked out? If so, we can skip an extra safety check
198
+
199
+ tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
200
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
201
+
202
+ # Define strides of inputs
203
+ stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}}
204
+ stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
205
+ stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
206
+
207
+ Z = {{size("Q", 0)}}
208
+ HQ = {{size("Q", 1)}}
209
+ Q_LEN = {{size("Q", 2)}}
210
+ KV_LEN = {{size("K", 2)}}
211
+
212
+ MATMUL_PRECISION = Q.dtype.element_ty
213
+
214
+ q_start = tl.program_id(0)
215
+ off_z = tl.program_id(1) // HQ
216
+ off_hq = tl.program_id(1) % HQ
217
+ off_hkv = off_hq // GQA_SHARED_HEADS
218
+ off_g = off_hq % GQA_SHARED_HEADS
219
+
220
+ q_offset = off_z * stride_qz + off_hq * stride_qh
221
+ k_offset = off_z * stride_kz + off_hkv * stride_kh
222
+ v_offset = off_z * stride_vz + off_hkv * stride_vh
223
+
224
+ Q = Q + q_offset
225
+ K = K + k_offset
226
+ V = V + v_offset
227
+
228
+ SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
229
+ SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
230
+
231
+ sparse_idx_z = off_z % SPARSE_Z
232
+ sparse_idx_hq = off_hq % SPARSE_HQ
233
+
234
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
235
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
236
+
237
+ SPARSE_Q_BLOCK_CNT: tl.constexpr = tl.cdiv(Q_LEN, SPARSE_Q_BLOCK_SIZE)
238
+ SPARSE_KV_BLOCK_CNT: tl.constexpr = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
239
+
240
+ # initialize pointer to m and l
241
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
242
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
243
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)
244
+
245
+ offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
246
+
247
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
248
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
249
+ sparse_kv_num_blks_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT + q_start // SPARSE_Q_MULTIPLE
250
+ sparse_kv_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (q_start // SPARSE_Q_MULTIPLE) * SPARSE_KV_BLOCK_CNT # noqa: B950
251
+
252
+ Q_block_ptr = tl.make_block_ptr(
253
+ base=Q,
254
+ shape=(Q_LEN, QK_HEAD_DIM),
255
+ strides=(stride_qm, stride_qk),
256
+ offsets=(q_start * BLOCK_M, 0),
257
+ block_shape=(BLOCK_M, QK_HEAD_DIM),
258
+ order=(1, 0)
259
+ )
260
+
261
+ # load q: it stays in SRAM throughout the inner loop.
262
+ if IS_DIVISIBLE:
263
+ q = tl.load(Q_block_ptr)
264
+ else:
265
+ # boundary check is not free, so we only do it when necessary.
266
+ q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option = "zero")
267
+
268
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
269
+ # We don't know anything "special" about these blocks, so we need to apply
270
+ # both score_mod and mask_mod to it
271
+ kv_indices = KV_IDX + sparse_kv_idx_offset
272
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
273
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
274
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
275
+
276
+ K_block_ptr = tl.make_block_ptr(
277
+ base=K,
278
+ shape=(QK_HEAD_DIM, KV_LEN),
279
+ strides=(stride_kk, stride_kn),
280
+ offsets=(0, kv_start),
281
+ block_shape=(QK_HEAD_DIM, BLOCK_N),
282
+ order=(0, 1)
283
+ )
284
+ V_block_ptr = tl.make_block_ptr(
285
+ base=V,
286
+ shape=(KV_LEN, V_HEAD_DIM),
287
+ strides=(stride_vn, stride_vk),
288
+ offsets=(kv_start, 0),
289
+ block_shape=(BLOCK_N, V_HEAD_DIM),
290
+ order=(1, 0)
291
+ )
292
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
293
+
294
+ acc, l_i, m_i = forward_inner(
295
+ {{gen_argdefs()}},
296
+ q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
297
+ acc, l_i, m_i,
298
+ off_z, off_hq, offs_m[:, None], offs_n[None, :],
299
+ kv_indices, kv_num_blocks,
300
+ 0, block_n_end,
301
+ MATMUL_PRECISION,
302
+ IS_FULL_BLOCKS=False,
303
+ )
304
+
305
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
306
+ # We know these blocks are guaranteed to be "full", so we don't need to
307
+ # apply mask_mod to them - only score_mod
308
+ if HAS_FULL_BLOCKS:
309
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
310
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
311
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
312
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
313
+ block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
314
+
315
+ K_block_ptr = tl.make_block_ptr(
316
+ base=K,
317
+ shape=(QK_HEAD_DIM, KV_LEN),
318
+ strides=(stride_kk, stride_kn),
319
+ offsets=(0, kv_start),
320
+ block_shape=(QK_HEAD_DIM, BLOCK_N),
321
+ order=(0, 1)
322
+ )
323
+ V_block_ptr = tl.make_block_ptr(
324
+ base=V,
325
+ shape=(KV_LEN, V_HEAD_DIM),
326
+ strides=(stride_vn, stride_vk),
327
+ offsets=(kv_start, 0),
328
+ block_shape=(BLOCK_N, V_HEAD_DIM),
329
+ order=(1, 0)
330
+ )
331
+ offs_n = kv_start + tl.arange(0, BLOCK_N)
332
+
333
+ acc, l_i, m_i = forward_inner(
334
+ {{gen_argdefs()}},
335
+ q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
336
+ acc, l_i, m_i,
337
+ off_z, off_hq, offs_m[:, None], offs_n[None, :],
338
+ kv_indices, kv_num_blocks,
339
+ 0, block_n_end,
340
+ MATMUL_PRECISION,
341
+ IS_FULL_BLOCKS=True,
342
+ )
343
+
344
+
345
+ # [Note] Handle fully masked out rows:
346
+ # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
347
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
348
+ l_i = tl.where(l_i == 0.0, 1, l_i)
349
+
350
+ acc = acc / l_i[:, None]
351
+ idx_z = tl.program_id(1) // HQ
352
+ idx_hq = tl.program_id(1) % HQ
353
+ idx_m = offs_m[:, None]
354
+ idx_d = tl.arange(0, V_HEAD_DIM)[None, :]
355
+
356
+ mask = idx_m < Q_LEN
357
+ # TODO generalize and add proper mask support
358
+ {{store_output(("idx_z", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
359
+
360
+ # TODO dont want to write this if we dont require grad
361
+ if OUTPUT_LOGSUMEXP:
362
+ off_hz = tl.program_id(1)
363
+ l_ptrs = LSE + off_hz * Q_LEN + offs_m
364
+ lse = m_i + tl.math.log2(l_i)
365
+ if IS_DIVISIBLE:
366
+ tl.store(l_ptrs, lse)
367
+ else:
368
+ tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
369
+ """
370
+
371
+
372
+ compute_forward_inner = r"""
373
+ @triton.jit
374
+ def forward_inner(
375
+ {{gen_argdefs()}},
376
+ q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
377
+ # accumulated values
378
+ acc, l_i, m_i,
379
+ # Offsets used as inputs to score_mod & mask_mod
380
+ # of size [BLOCK_M, BLOCK_N] or scalar.
381
+ off_z, off_h, offs_m, offs_n,
382
+ # blocksparse data
383
+ kv_indices, kv_num_blocks,
384
+ # start kv and end kv block
385
+ block_n_start, block_n_end,
386
+ MATMUL_PRECISION,
387
+ IS_FULL_BLOCKS,
388
+ ):
389
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
390
+ {{gen_defines() | indent_except_first(1)}}
391
+
392
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
393
+ RCP_LN2: tl.constexpr = 1.44269504
394
+
395
+ if PRESCALE_QK:
396
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
397
+
398
+ # loop over k, v and update accumulator until block_n_end
399
+ for start_n in range(block_n_start, block_n_end):
400
+ if IS_DIVISIBLE:
401
+ acc, l_i, m_i = forward_block_mn(
402
+ {{gen_argdefs()}},
403
+ q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
404
+ # accumulated values
405
+ acc, l_i, m_i,
406
+ # Offsets
407
+ off_z, off_h, offs_m, offs_n,
408
+ MATMUL_PRECISION, RCP_LN2,
409
+ IS_FULL_BLOCKS,
410
+ )
411
+ else:
412
+ # Benchmarks show that even when we apply mod & mask to every block for
413
+ # non-divisible seqlens, it's on par with or slightly faster than applying them
414
+ # only to the last block in fwd. However, we choose a different strategy for
415
+ # bwd, where we only apply mod & mask to the last block because it's much faster.
416
+ acc, l_i, m_i = forward_block_mn(
417
+ {{gen_argdefs()}},
418
+ q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
419
+ # accumulated values
420
+ acc, l_i, m_i,
421
+ # Offsets
422
+ off_z, off_h, offs_m, offs_n,
423
+ MATMUL_PRECISION, RCP_LN2,
424
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
425
+ )
426
+
427
+ # update pointers
428
+ offset = get_offset_for_next_block(
429
+ start_n, kv_indices, kv_num_blocks,
430
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N
431
+ )
432
+
433
+ V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
434
+ K_block_ptr = tl.advance(K_block_ptr, (0, offset))
435
+
436
+ offs_n = offs_n + offset
437
+
438
+ return acc, l_i, m_i
439
+
440
+ """
441
+
442
+
443
+ compute_forward_block_mn = r"""
444
+ @triton.jit
445
+ def forward_block_mn(
446
+ {{gen_argdefs()}},
447
+ q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
448
+ # accumulated values
449
+ acc, l_i, m_i,
450
+ # Offsets
451
+ off_z, off_h, offs_m, offs_n,
452
+ MATMUL_PRECISION, RCP_LN2,
453
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
454
+ ):
455
+ # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
456
+ {{gen_defines() | indent_except_first(1)}}
457
+
458
+ # -- load k --
459
+ if IS_DIVISIBLE:
460
+ k = tl.load(K_block_ptr)
461
+ else:
462
+ k = tl.load(K_block_ptr, boundary_check=(1,), padding_option = "zero")
463
+ # -- compute qk ---
464
+ qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
465
+ if not PRESCALE_QK:
466
+ qk *= SM_SCALE
467
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
468
+ if CHECK_BLOCK_BOUNDARY:
469
+ # If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
470
+ # which is larger than the actual number of elements. To avoid out-of-bounds memory accesses,
471
+ # we need to mask out the elements that fall outside of Q_LEN & KV_LEN.
472
+ m = offs_m % Q_LEN
473
+ n = offs_n % KV_LEN
474
+ else:
475
+ m = offs_m
476
+ n = offs_n
477
+
478
+ {{ modification(
479
+ subgraph_number=0,
480
+ output_name="post_mod_scores",
481
+ score="qk",
482
+ b="off_z",
483
+ h="off_h",
484
+ m="m",
485
+ n="n",
486
+ out="qk"
487
+ ) | indent_except_first(1) }}
488
+
489
+ if CHECK_BLOCK_BOUNDARY:
490
+ # Mask out the elements that lie beyond KV_LEN for non-divisible seqlens.
491
+ post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
492
+
493
+ if not IS_FULL_BLOCKS:
494
+ {{ modification(
495
+ subgraph_number=1,
496
+ output_name="mask_mod_output",
497
+ score="qk",
498
+ b="off_z",
499
+ h="off_h",
500
+ m="m",
501
+ n="n",
502
+ ) | indent_except_first(2) }}
503
+
504
+ if CHECK_BLOCK_BOUNDARY:
505
+ mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, float("-inf"))
506
+ # apply mask for partially unmasked blocks
507
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
508
+
509
+ # TODO: In the case that score_mod is linear, this can be LICMed
510
+ if not PRESCALE_QK:
511
+ post_mod_scores *= RCP_LN2
512
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
513
+
514
+ # -- compute scaling constant ---
515
+ m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
516
+ if not ROWS_GUARANTEED_SAFE:
517
+ masked_out_rows = (m_ij == float("-inf"))
518
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
519
+ else:
520
+ m_ij_masked = m_ij
521
+
522
+ alpha = tl.math.exp2(m_i - m_ij_masked)
523
+ p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
524
+
525
+ # NB: l_i update is pulled up here since it's a bit faster
526
+ # NB: For headdim=256, it's faster to move it back down to after m_i =
527
+ # m_ij
528
+ l_i = l_i * alpha + tl.sum(p, 1)
529
+ # # -- scale and update acc --
530
+ acc = acc * alpha[:, None]
531
+
532
+ if IS_DIVISIBLE:
533
+ v = tl.load(V_block_ptr)
534
+ else:
535
+ v = tl.load(V_block_ptr, boundary_check=(0,), padding_option = "zero")
536
+ acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
537
+
538
+ # -- update m_i
539
+ m_i = m_ij
540
+
541
+ return acc, l_i, m_i
542
+
543
+ """
544
+
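The templates above do all exponentiation in base 2: scores are scaled by RCP_LN2 = 1/ln(2), either up front via PRESCALE_QK or per block, so that tl.math.exp2 reproduces the natural-base softmax, and the logsumexp is accumulated in the same log2 units. A quick eager-mode check of that identity, purely illustrative:

import torch

RCP_LN2 = 1.44269504  # 1 / ln(2), the same constant used in the kernels
scores = torch.randn(4, 8)
p2 = torch.exp2(scores * RCP_LN2)               # equals torch.exp(scores)
softmax_base2 = p2 / p2.sum(-1, keepdim=True)
assert torch.allclose(softmax_base2, torch.softmax(scores, dim=-1), atol=1e-5)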
545
+
546
+ flex_attention_template = TritonTemplate(
547
+ name="flex_attention",
548
+ grid=flex_attention_grid,
549
+ source=compute_flex_attention
550
+ + compute_forward_inner
551
+ + compute_next_offset_func
552
+ + compute_forward_block_mn,
553
+ )
554
+
555
+
556
+ def _use_flex_decoding(query, kernel_options):
557
+ # Decide which kernel to use; return True if the flex decoding kernel should be used.
558
+ return (
559
+ not kernel_options.get("FORCE_USE_FLEX_ATTENTION", False)
560
+ ) and V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 128))
561
+
562
+
563
+ _h100_default_config = {
564
+ (torch.float32, 64): (128, 32, 4, 3),
565
+ (torch.float32, 128): (32, 64, 4, 3),
566
+ (torch.float32, 256): (32, 32, 4, 3),
567
+ (torch.bfloat16, 64): (128, 128, 4, 3),
568
+ (torch.bfloat16, 128): (128, 64, 8, 3),
569
+ (torch.bfloat16, 256): (64, 32, 4, 3),
570
+ (torch.float16, 64): (128, 128, 4, 3),
571
+ (torch.float16, 128): (128, 128, 8, 3),
572
+ (torch.float16, 256): (64, 32, 4, 3),
573
+ }
574
+
575
+ _a100_default_config = {
576
+ (torch.float32, 64): (128, 32, 4, 3),
577
+ (torch.float32, 128): (128, 32, 4, 3),
578
+ (torch.float32, 256): (64, 16, 4, 3),
579
+ (torch.bfloat16, 64): (128, 64, 4, 3),
580
+ (torch.bfloat16, 128): (128, 64, 8, 3),
581
+ (torch.bfloat16, 256): (32, 64, 4, 3),
582
+ (torch.float16, 64): (128, 64, 4, 3),
583
+ (torch.float16, 128): (128, 64, 8, 3),
584
+ (torch.float16, 256): (32, 64, 4, 3),
585
+ }
586
+
587
+
588
+ def _get_default_config_fwd(query) -> Tuple[int, int, int, int]:
589
+ dtype = query.get_dtype()
590
+ head_dim = query.get_size()[-1]
591
+ default_config = None
592
+
593
+ if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0): # H100
594
+ if dtype == torch.float32:
595
+ default_config = (64, 64, 4, 3)
596
+ else:
597
+ default_config = (128, 64, 4, 3)
598
+ default_config = _h100_default_config.get((dtype, head_dim), default_config)
599
+ elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0): # A100
600
+ if dtype == torch.float32:
601
+ default_config = (64, 64, 4, 3)
602
+ else:
603
+ default_config = (128, 64, 4, 3)
604
+ default_config = _a100_default_config.get((dtype, head_dim), default_config)
605
+ else: # modest hardware or extremely large head_dim
606
+ if dtype == torch.float32:
607
+ default_config = (32, 16, 4, 3)
608
+ else:
609
+ default_config = (64, 32, 4, 3)
610
+
611
+ return default_config
612
+
613
+
614
+ def _get_default_config_bwd(query) -> Tuple[int, int, int, int]:
615
+ head_dim = query.get_size()[-1]
616
+ dtype = query.get_dtype()
617
+
618
+ if dtype == torch.float32:
619
+ return (16, 16, 4, 1)
620
+ if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0): # H100
621
+ if head_dim == 64:
622
+ return (64, 64, 4, 3)
623
+ elif head_dim == 128:
624
+ return (64, 128, 8, 3)
625
+ else:
626
+ return (64, 64, 4, 2)
627
+ elif torch.cuda.get_device_capability() >= (8, 0): # A100
628
+ if head_dim == 64:
629
+ return (32, 128, 4, 3)
630
+ elif head_dim == 128:
631
+ return (64, 128, 8, 3)
632
+ else:
633
+ return (64, 64, 4, 2)
634
+ else: # modest hardware or extremely large head_dim
635
+ return (16, 16, 4, 1)
636
+
637
+
638
+ def create_num_blocks_fake_generator(sparse_indices):
639
+ # The idea here is that we need to create a real tensor with real data
640
+ # that's representative for benchmarking.
641
+ # For example, returning all zeros for the `kv_num_blocks` input would mean
642
+ # that we are computing 0 blocks for each row, which would provide bogus
643
+ # autotuning results.
644
+ #
645
+ # In this case, we choose to use min(16, max_block) blocks, because I
646
+ # (Horace) think it'll probably result in pretty representative performance.
647
+ # If it's too short then prefetching won't help. If it's too long then
648
+ # autotuning will take longer for no good reason.
649
+ def create_num_blocks_fake(x) -> torch.Tensor:
650
+ num_blocks_for_autotuning = min(16, sparse_indices.shape[-1])
651
+ return torch.full(
652
+ x.get_size(),
653
+ int(num_blocks_for_autotuning),
654
+ dtype=x.get_dtype(),
655
+ device=x.get_device(),
656
+ )
657
+
658
+ return create_num_blocks_fake
659
+
660
+
661
+ def create_indices_fake(x) -> torch.Tensor:
662
+ indices = torch.arange(
663
+ 0, int(x.get_size()[-1]), dtype=x.get_dtype(), device=x.get_device()
664
+ )
665
+ indices = indices.expand(x.get_size()).contiguous()
666
+ return indices
667
+
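Together these helpers give the autotuner plausible block-sparse inputs instead of zeros: in-order indices 0..N-1 and a block count clamped to min(16, N). Roughly what they produce, with made-up shapes:

import torch

kv_idx_size = (1, 1, 8, 32)   # hypothetical (B, H, num_q_blocks, num_kv_blocks)
fake_indices = torch.arange(32, dtype=torch.int32).expand(kv_idx_size).contiguous()
fake_num_blocks = torch.full(kv_idx_size[:-1], min(16, 32), dtype=torch.int32)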
668
+
669
+ from torch._inductor.kernel.flex_decoding import create_flex_decoding_kernel
670
+
671
+
672
+ # TODO: We probably also need a layout constraint?
673
+ @register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
674
+ def flex_attention(
675
+ query,
676
+ key,
677
+ value,
678
+ subgraph,
679
+ block_mask,
680
+ scale,
681
+ kernel_options,
682
+ score_mod_other_buffers,
683
+ mask_mod_other_buffers,
684
+ ):
685
+ (
686
+ kv_num_blocks,
687
+ kv_indices,
688
+ full_kv_num_blocks,
689
+ full_kv_indices,
690
+ q_num_blocks,
691
+ q_indices,
692
+ full_q_num_blocks,
693
+ full_q_indices,
694
+ SPARSE_KV_BLOCK_SIZE,
695
+ SPARSE_Q_BLOCK_SIZE,
696
+ mask_graph,
697
+ ) = block_mask
698
+ placeholder_inps = [
699
+ create_placeholder(name, dtype, query.get_device())
700
+ for name, dtype in [
701
+ ("score", query.get_dtype()),
702
+ ("b", torch.int32),
703
+ ("h", torch.int32),
704
+ ("m", torch.int32),
705
+ ("n", torch.int32),
706
+ ]
707
+ ]
708
+ subgraph_buffer = build_subgraph_buffer(
709
+ placeholder_inps + list(score_mod_other_buffers), subgraph
710
+ )
711
+ mask_graph_placeholder_inps = [
712
+ create_placeholder(name, dtype, query.get_device())
713
+ for name, dtype in [
714
+ ("b", torch.int32),
715
+ ("h", torch.int32),
716
+ ("m", torch.int32),
717
+ ("n", torch.int32),
718
+ ]
719
+ ]
720
+ mask_graph_buffer = build_subgraph_buffer(
721
+ mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
722
+ )
723
+ kernel_options = dict(kernel_options)
724
+ kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
725
+ if _use_flex_decoding(query, kernel_options):
726
+ return create_flex_decoding_kernel(
727
+ query,
728
+ key,
729
+ value,
730
+ block_mask,
731
+ scale,
732
+ kernel_options,
733
+ subgraph_buffer,
734
+ mask_graph_buffer,
735
+ score_mod_other_buffers,
736
+ mask_mod_other_buffers,
737
+ )
738
+
739
+ (
740
+ query,
741
+ key,
742
+ value,
743
+ kv_num_blocks,
744
+ kv_indices,
745
+ full_kv_num_blocks,
746
+ full_kv_indices,
747
+ q_num_blocks,
748
+ q_indices,
749
+ full_q_num_blocks,
750
+ full_q_indices,
751
+ ) = maybe_realize(
752
+ [
753
+ query,
754
+ key,
755
+ value,
756
+ kv_num_blocks,
757
+ kv_indices,
758
+ full_kv_num_blocks,
759
+ full_kv_indices,
760
+ q_num_blocks,
761
+ q_indices,
762
+ full_q_num_blocks,
763
+ full_q_indices,
764
+ ]
765
+ )
766
+
767
+ Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
768
+ Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
769
+ assert Bq == Bkv, "Batch dimension must match"
770
+ B = Bq
771
+
772
+ if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0:
773
+ kernel_options.setdefault("IS_DIVISIBLE", False)
774
+ else:
775
+ kernel_options.setdefault("IS_DIVISIBLE", True)
776
+
777
+ # Reuse query strides for output layout despite different last dimension.
778
+ # This works because only the last dim differs and we check it is contiguous.
779
+ q_strides = query.get_stride()
780
+ assert q_strides[-1] == 1, "Query must be contiguous in the last dimension"
781
+
782
+ # Construct output layout with strides matching the query.
783
+ out_size = [B, Hq, seq_len_q, v_head_dim]
784
+ stride_order = get_stride_order(query.get_stride())
785
+ fill_order = stride_order2fill_order(stride_order)
786
+ out_strides = construct_strides(out_size, fill_order)
787
+
788
+ layout = FixedLayout(
789
+ query.get_device(),
790
+ query.get_dtype(),
791
+ [B, Hq, seq_len_q, v_head_dim],
792
+ stride=out_strides,
793
+ )
794
+ # see NOTE:[TritonTemplates with multiple outputs]
795
+ logsumexp_shape = [B, Hq, seq_len_q]
796
+ logsumexp = empty_strided(
797
+ logsumexp_shape,
798
+ None,
799
+ dtype=torch.float32, # The logsumexp is always stored in fp32 regardless of the input dtype
800
+ device=query.get_device(),
801
+ )
802
+ kernel_options.setdefault("SM_SCALE", scale)
803
+
804
+ # Determine GQA broadcast factor.
805
+ gqa_shared_heads = Hq // Hkv
806
+ kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)
807
+
808
+ # Inside the Triton kernel, only apply partial masking if partial blocks are computed.
809
+ # full_kv_num_blocks is None if partial blocks are not computed
810
+ has_full_blocks = full_kv_num_blocks is not None
811
+ kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
812
+ if not has_full_blocks:
813
+ full_kv_num_blocks, full_kv_indices = (
814
+ empty(0, device=query.get_device()) for _ in range(2)
815
+ )
816
+ kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim)
817
+ kernel_options.setdefault("V_HEAD_DIM", v_head_dim)
818
+
819
+ choices: List[Any] = []
820
+ configs: List[Tuple[int, int, int, int]] = []
821
+ configs.append(_get_default_config_fwd(query))
822
+ if config.max_autotune:
823
+ configs += [
824
+ (128, 64, 4, 3),
825
+ (128, 128, 4, 3),
826
+ (128, 128, 8, 2),
827
+ (64, 128, 4, 3),
828
+ (64, 64, 4, 3),
829
+ ]
830
+
831
+ # Note: we don't need to pass in the captured buffers explicitly
832
+ # because they're implicitly added by the score_mod function.
833
+ # We do, however, need to pass them in explicitly for autotuning.
834
+
835
+ for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
836
+ if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0 or SPARSE_Q_BLOCK_SIZE % BLOCK_M != 0:
837
+ continue
838
+ # Work around https://github.com/pytorch/pytorch/issues/129625
839
+ if num_stages == 2:
840
+ continue
841
+
842
+ # Performance tuning
843
+ kernel_options.setdefault("BLOCK_M", BLOCK_M)
844
+ kernel_options.setdefault("BLOCK_N", BLOCK_N)
845
+ # Blocksparse options
846
+ kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
847
+ kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
848
+
849
+ flex_attention_template.maybe_append_choice(
850
+ choices=choices,
851
+ input_nodes=[
852
+ query,
853
+ key,
854
+ value,
855
+ logsumexp,
856
+ kv_num_blocks,
857
+ kv_indices,
858
+ full_kv_num_blocks,
859
+ full_kv_indices,
860
+ ],
861
+ layout=layout,
862
+ subgraphs=[
863
+ subgraph_buffer,
864
+ mask_graph_buffer,
865
+ ],
866
+ mutated_inputs=[
867
+ logsumexp,
868
+ ],
869
+ num_stages=num_stages,
870
+ num_warps=num_warps,
871
+ call_sizes=query.get_size(),
872
+ **kernel_options,
873
+ )
874
+ inputs_for_autotuning = (
875
+ [
876
+ query,
877
+ key,
878
+ value,
879
+ logsumexp,
880
+ kv_num_blocks,
881
+ kv_indices,
882
+ full_kv_num_blocks,
883
+ full_kv_indices,
884
+ ]
885
+ + list(score_mod_other_buffers)
886
+ + list(mask_mod_other_buffers)
887
+ )
888
+ input_gen_fns = {
889
+ 4: create_num_blocks_fake_generator(kv_indices),
890
+ 5: create_indices_fake,
891
+ 6: create_num_blocks_fake_generator(full_kv_indices),
892
+ 7: create_indices_fake,
893
+ }
894
+ return (
895
+ autotune_select_algorithm(
896
+ "flex_attention",
897
+ choices,
898
+ inputs_for_autotuning,
899
+ layout,
900
+ input_gen_fns=input_gen_fns,
901
+ ),
902
+ logsumexp,
903
+ )
904
+
905
+
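For context on how this lowering is reached, a sketch of a compiled call through the public FlexAttention entry point (assuming torch.nn.attention.flex_attention from a recent PyTorch release and a CUDA device; the score_mod here is the usual causal example):

import torch
from torch.nn.attention.flex_attention import flex_attention

def causal(score, b, h, q_idx, kv_idx):
    # keep scores on or below the diagonal, mask out the rest
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

q = k = v = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)
out = torch.compile(flex_attention)(q, k, v, score_mod=causal)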
906
+ # ---------------------------- Backward HOP Implementation ----------------------------
907
+
908
+
909
+ def flex_attention_backward_grid(
910
+ batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta
911
+ ):
912
+ """How is this kernel parallelized?
913
+ Currently this only parallelizes over batch * kv_heads, but we can, and want to,
914
+ parallelize over ceil_div(q_heads//kv_heads * num_key_value, key_value_block_size).
915
+ Doing so will require either atomic updates to some grad values or a two-pass kernel design.
916
+ """
917
+ import triton
918
+
919
+ return (
920
+ triton.cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads)
921
+ + triton.cdiv(num_key_value, meta["BLOCK_N1"]),
922
+ 1,
923
+ batch_size * kv_heads,
924
+ )
925
+
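Concretely, for hypothetical sizes (batch_size=4, q_heads=8, kv_heads=2, 1024 queries, 2048 keys, BLOCK_M2=64, BLOCK_N1=64) this grid evaluates to (96, 1, 8): the first 32 programs along axis 0 compute DK/DV blocks and the remaining 64 compute DQ blocks for the shared query heads.

import math

batch_size, q_heads, kv_heads = 4, 8, 2
num_queries, num_key_value = 1024, 2048
BLOCK_M2, BLOCK_N1 = 64, 64
grid = (
    math.ceil(num_queries / BLOCK_M2) * (q_heads // kv_heads)
    + math.ceil(num_key_value / BLOCK_N1),
    1,
    batch_size * kv_heads,
)
assert grid == (96, 1, 8)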
926
+
927
+ flex_attention_backward_template = TritonTemplate(
928
+ name="flex_attention_backward",
929
+ grid=flex_attention_backward_grid,
930
+ source=r"""
931
+ {{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}}
932
+ # Sub notation for this kernel:
933
+ #
934
+ # Q: Query, K: Key, V: Value
935
+ # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
936
+ # DELTA: Precomputed sum(OUT*DO, axis=-1)
937
+ # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
938
+ # DK: Derivative of Key, which is written out via the store_output call due to some limitations with
939
+ # inductor codegen
940
+ # M: Number of queries, N: Number of keys/values
941
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
942
+ # V_HEAD_DIM: The dimension of the value embeddings
943
+ # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
944
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
945
+ # (Modifiable) Performance tuning options
946
+ # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
947
+ # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
948
+ # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
949
+ # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
950
+ #
951
+ # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
952
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
953
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
954
+ # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
955
+ # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
956
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
957
+ # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
958
+ # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
959
+ # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
960
+
961
+ # The below are kernel options that can be applied for certain score_mods,
962
+ # or involve a numerics vs. perf tradeoff
963
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
964
+ # about 20% more numerical error, but slightly faster.
965
+
966
+ # Define strides of inputs
967
+ stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}}
968
+ stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}}
969
+ stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}}
970
+ stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}}
971
+
972
+ stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}}
973
+ stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}}
974
+
975
+ Z = {{size("Q", 0)}}
976
+ HQ = {{size("Q", 1)}}
977
+ HKV = {{size("K", 1)}}
978
+ Q_LEN = {{size("Q", 2)}}
979
+ KV_LEN = {{size("K", 2)}}
980
+
981
+ MATMUL_PRECISION = Q.dtype.element_ty
982
+
983
+ pid = tl.program_id(0)
984
+ NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
985
+ NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
986
+
987
+ off_hz = tl.program_id(2)
988
+ off_z = off_hz // HKV # batch idx
989
+ off_hkv = off_hz % HKV # kv head idx
990
+
991
+ SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
992
+ SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
993
+
994
+ sparse_idx_z = off_z % SPARSE_Z
995
+
996
+ k_adj = (stride_kh * off_hkv + stride_kz * off_z).to(tl.int64)
997
+ v_adj = (stride_vh * off_hkv + stride_vz * off_z).to(tl.int64)
998
+ dv_adj = (stride_dvh * off_hkv + stride_dvz * off_z).to(tl.int64)
999
+
1000
+ # offset K, V, DV pointers for batch/kv-head
1001
+ K += k_adj
1002
+ V += v_adj
1003
+ DV += dv_adj
1004
+
1005
+ RCP_LN2 = 1.44269504
1006
+ offs_k = tl.arange(0, QK_HEAD_DIM)
1007
+ offs_v = tl.arange(0, V_HEAD_DIM)
1008
+
1009
+ if pid >= NUM_KV_BLOCKS:
1010
+ off_pid = pid - NUM_KV_BLOCKS
1011
+ # THIS BLOCK DOES DQ
1012
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
1013
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
1014
+ off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
1015
+ start_m2_block = off_pid % NUM_Q_BLOCKS
1016
+ off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
1017
+ stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
1018
+ stride_kv_idx_h = {{stride("KV_IDX", 1)}}
1019
+ stride_kv_idx_m = {{stride("KV_IDX", 2)}}
1020
+
1021
+ sparse_idx_hq2 = off_hq2 % SPARSE_HQ
1022
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
1023
+
1024
+ sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
1025
+ sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
1026
+
1027
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads.
1028
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_z).to(tl.int64)
1029
+ do_adj2 = (stride_doh * off_hq2 + stride_doz * off_z).to(tl.int64)
1030
+ dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_z).to(tl.int64)
1031
+ off_chz2 = ((off_z * HQ + off_hq2) * Q_LEN).to(tl.int64)
1032
+
1033
+ Q2 = Q + q_adj2
1034
+ DO2 = DO + do_adj2
1035
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
1036
+ # if Q is broadcasted)
1037
+ DQ2 = DQ + dq_adj2
1038
+ LSE2 = LSE + off_chz2
1039
+ DELTA2 = DELTA + off_chz2
1040
+
1041
+ dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
1042
+
1043
+ start_m2 = start_m2_block * BLOCK_M2
1044
+ offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
1045
+
1046
+ # load Q and do: they stay in SRAM throughout the inner loop.
1047
+ if IS_DIVISIBLE:
1048
+ q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd)
1049
+ do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_v[None, :] * stride_dod)
1050
+ else:
1051
+ q = tl.load(Q2 + offs_m2[:, None] * stride_qm + offs_k[None, :] * stride_qd, mask=offs_m2[:, None] < Q_LEN)
1052
+ do = tl.load(DO2 + offs_m2[:, None] * stride_dom + offs_v[None, :] * stride_dod, mask=offs_m2[:, None] < Q_LEN)
1053
+
1054
+ if PRESCALE_QK:
1055
+ q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
1056
+
1057
+ if IS_DIVISIBLE:
1058
+ Di = tl.load(DELTA2 + offs_m2)
1059
+ lse = tl.load(LSE2 + offs_m2)
1060
+ else:
1061
+ Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
1062
+ lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
1063
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
1064
+ lse = lse[:, None]
1065
+
1066
+ # ~~~~~~~~~~~ normal blocks (may need masking) ~~~~~~~~~~~~~~~~~~~~~~~~
1067
+ # KV_IDX and KV_NUM_BLKS are always contiguous.
1068
+ kv_indices = KV_IDX + sparse_kv_idx_offset
1069
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
1070
+ sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
1071
+
1072
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
1073
+ dq = bwd_dq_inner(
1074
+ {{gen_argdefs()}},
1075
+ K, V,
1076
+ dq, q, do, Di, lse,
1077
+ off_z, off_hq2, offs_m2, offs_n2,
1078
+ stride_kn, stride_kd, stride_vn, stride_vd,
1079
+ kv_indices, sparse_kv_num_blocks,
1080
+ MATMUL_PRECISION,
1081
+ IS_FULL_BLOCKS=False,
1082
+ )
1083
+
1084
+ if HAS_FULL_BLOCKS:
1085
+ # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1086
+ # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
1087
+ kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
1088
+ kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
1089
+ sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
1090
+
1091
+ offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
1092
+ dq = bwd_dq_inner(
1093
+ {{gen_argdefs()}},
1094
+ K, V,
1095
+ dq, q, do, Di, lse,
1096
+ off_z, off_hq2, offs_m2, offs_n2,
1097
+ stride_kn, stride_kd, stride_vn, stride_vd,
1098
+ kv_indices, sparse_kv_num_blocks,
1099
+ MATMUL_PRECISION,
1100
+ IS_FULL_BLOCKS=True,
1101
+ )
1102
+
1103
+ # Write back dQ.
1104
+ dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
1105
+ dq *= SM_SCALE
1106
+ if IS_DIVISIBLE:
1107
+ tl.store(dq_ptrs, dq)
1108
+ else:
1109
+ tl.store(dq_ptrs, dq, mask=offs_m2[:, None] < Q_LEN)
1110
+ else:
1111
+ # THIS BLOCK DOES DK & DV
1112
+ SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
1113
+ SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
1114
+
1115
+ pid_mask = pid // SPARSE_KV_MULTIPLE
1116
+
1117
+ stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}}
1118
+ stride_q_idx_h = {{stride("Q_IDX", 1)}}
1119
+ stride_q_idx_n = {{stride("Q_IDX", 2)}}
1120
+
1121
+ dv = tl.zeros([BLOCK_N1, V_HEAD_DIM], dtype=tl.float32)
1122
+ dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM], dtype=tl.float32)
1123
+
1124
+ start_n1 = pid * BLOCK_N1
1125
+ offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
1126
+
1127
+ # load K and V: they stay in SRAM throughout the inner loop.
1128
+ if IS_DIVISIBLE:
1129
+ k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd)
1130
+ v = tl.load(V + offs_n1[:, None] * stride_vn + offs_v[None, :] * stride_vd)
1131
+ else:
1132
+ k = tl.load(K + offs_n1[:, None] * stride_kn + offs_k[None, :] * stride_kd, mask=offs_n1[:, None] < KV_LEN)
1133
+ v = tl.load(V + offs_n1[:, None] * stride_vn + offs_v[None, :] * stride_vd, mask=offs_n1[:, None] < KV_LEN)
1134
+ if PRESCALE_QK:
1135
+ k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
1136
+
1137
+ for off_g in range(0, GQA_SHARED_HEADS):
1138
+ off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
1139
+
1140
+ # Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads.
1141
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_z).to(tl.int64)
1142
+ do_adj1 = (stride_doh * off_hq1 + stride_doz * off_z).to(tl.int64)
1143
+ dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_z).to(tl.int64)
1144
+ off_chz1 = ((off_z * HQ + off_hq1) * Q_LEN).to(tl.int64)
1145
+
1146
+ Q1 = Q + q_adj1
1147
+ DO1 = DO + do_adj1
1148
+ # TODO: This does not work if DQ is not the same layout as Q (for example,
1149
+ # if Q is broadcasted)
1150
+ LSE1 = LSE + off_chz1
1151
+ DELTA1 = DELTA + off_chz1
1152
+
1153
+ sparse_idx_hq1 = off_hq1 % SPARSE_HQ
1154
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
1155
+
1156
+ sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
1157
+ sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
1158
+
1159
+ # ~~~~~~~~~~~~~~~ normal blocks (may need masking) ~~~~~~~~~~~~~~~~~~~~~~
1160
+ # Q_IDX and Q_NUM_BLKS are always contiguous.
1161
+ q_indices = Q_IDX + sparse_q_idx_offset
1162
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
1163
+ sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
1164
+
1165
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
1166
+ dk, dv = bwd_dkdv_inner(
1167
+ {{gen_argdefs()}},
1168
+ Q1, DO1, DELTA1, LSE1,
1169
+ dk, dv, k, v,
1170
+ off_z, off_hq1, offs_n1, offs_m1,
1171
+ stride_qm, stride_qd, stride_dom, stride_dod,
1172
+ q_indices, sparse_q_num_blocks,
1173
+ MATMUL_PRECISION,
1174
+ IS_FULL_BLOCKS=False,
1175
+ )
1176
+
1177
+
1178
+ if HAS_FULL_BLOCKS:
1179
+ # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1180
+ # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
1181
+ q_indices = FULL_Q_IDX + sparse_q_idx_offset
1182
+ q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
1183
+ sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
1184
+
1185
+ offs_m1 = q_start + tl.arange(0, BLOCK_M1)
1186
+ dk, dv = bwd_dkdv_inner(
1187
+ {{gen_argdefs()}},
1188
+ Q1, DO1, DELTA1, LSE1,
1189
+ dk, dv, k, v,
1190
+ off_z, off_hq1, offs_n1, offs_m1,
1191
+ stride_qm, stride_qd, stride_dom, stride_dod,
1192
+ q_indices, sparse_q_num_blocks,
1193
+ MATMUL_PRECISION,
1194
+ IS_FULL_BLOCKS=True,
1195
+ )
1196
+
1197
+ # Write back dV and dK.
1198
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
1199
+
1200
+ index_n = offs_n1[:, None]
1201
+ index_k = offs_k[None, :]
1202
+
1203
+ if IS_DIVISIBLE:
1204
+ tl.store(dv_ptrs, dv)
1205
+ else:
1206
+ tl.store(dv_ptrs, dv, mask=index_n < KV_LEN)
1207
+
1208
+ dk *= SM_SCALE
1209
+ mask = index_n < KV_LEN
1210
+ {{store_output(("off_z", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
1211
+
1212
+ @triton.jit
1213
+ def bwd_dq_inner(
1214
+ {{gen_argdefs()}},
1215
+ K, V, # pointers
1216
+ dq, q, do, Di, lse,
1217
+ off_z, off_hq, offs_m2, offs_n2,
1218
+ stride_kn, stride_kd, stride_vn, stride_vd,
1219
+ kv_indices, sparse_kv_num_blocks,
1220
+ MATMUL_PRECISION,
1221
+ IS_FULL_BLOCKS,
1222
+ ):
1223
+ {{gen_defines() | indent_except_first(1) }}
1224
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
1225
+ RCP_LN2: tl.constexpr = 1.44269504
1226
+ Q_LEN = {{size("Q", 2)}}
1227
+ KV_LEN = {{size("K", 2)}}
1228
+
1229
+ offs_k = tl.arange(0, QK_HEAD_DIM)
1230
+ offs_v = tl.arange(0, V_HEAD_DIM)
1231
+
1232
+ kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
1233
+ vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
1234
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
1235
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
1236
+
1237
+ hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
1238
+ if not IS_DIVISIBLE:
1239
+ if hi >= 1:
1240
+ for start_n in range(0, hi - 1):
1241
+ dq = bwd_dq_block_mn(
1242
+ {{gen_argdefs()}},
1243
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
1244
+ off_z, off_hq, offs_m2, offs_n2,
1245
+ stride_kn, stride_kd, stride_vn, stride_vd,
1246
+ kv_indices, sparse_kv_num_blocks,
1247
+ MATMUL_PRECISION, RCP_LN2,
1248
+ IS_FULL_BLOCKS,
1249
+ )
1250
+
1251
+ # Increment pointers.
1252
+ offset = get_offset_for_next_block(
1253
+ start_n, kv_indices, sparse_kv_num_blocks,
1254
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2
1255
+ )
1256
+
1257
+ kT_ptrs += offset * stride_kn
1258
+ vT_ptrs += offset * stride_vn
1259
+
1260
+ offs_n2 += offset
1261
+
1262
+ dq = bwd_dq_block_mn(
1263
+ {{gen_argdefs()}},
1264
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
1265
+ off_z, off_hq, offs_m2, offs_n2,
1266
+ stride_kn, stride_kd, stride_vn, stride_vd,
1267
+ kv_indices, sparse_kv_num_blocks,
1268
+ MATMUL_PRECISION, RCP_LN2,
1269
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
1270
+ )
1271
+ else:
1272
+ for start_n in range(0, hi):
1273
+ dq = bwd_dq_block_mn(
1274
+ {{gen_argdefs()}},
1275
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
1276
+ off_z, off_hq, offs_m2, offs_n2,
1277
+ stride_kn, stride_kd, stride_vn, stride_vd,
1278
+ kv_indices, sparse_kv_num_blocks,
1279
+ MATMUL_PRECISION, RCP_LN2,
1280
+ IS_FULL_BLOCKS,
1281
+ )
1282
+
1283
+ # Increment pointers.
1284
+ offset = get_offset_for_next_block(
1285
+ start_n, kv_indices, sparse_kv_num_blocks,
1286
+ SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2
1287
+ )
1288
+
1289
+ kT_ptrs += offset * stride_kn
1290
+ vT_ptrs += offset * stride_vn
1291
+
1292
+ offs_n2 += offset
1293
+
1294
+ return dq
1295
+
1296
+
1297
+ @triton.jit
1298
+ def bwd_dq_block_mn(
1299
+ {{gen_argdefs()}},
1300
+ dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
1301
+ off_z, off_hq, offs_m2, offs_n2,
1302
+ stride_kn, stride_kd, stride_vn, stride_vd,
1303
+ kv_indices, sparse_kv_num_blocks,
1304
+ MATMUL_PRECISION, RCP_LN2,
1305
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
1306
+ ):
1307
+ {{gen_defines() | indent_except_first(1)}}
1308
+
1309
+ if IS_DIVISIBLE:
1310
+ kT = tl.load(kT_ptrs)
1311
+ else:
1312
+ kT = tl.load(kT_ptrs, mask=offs_n2[None, :] < KV_LEN)
1313
+ qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
1314
+ if not PRESCALE_QK:
1315
+ qk *= SM_SCALE
1316
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
1317
+ pre_mod_scores = qk
1318
+ if CHECK_BLOCK_BOUNDARY:
1319
+ m = offs_m2[:, None] % Q_LEN
1320
+ n = offs_n2[None, :] % KV_LEN
1321
+ else:
1322
+ m = offs_m2[:, None]
1323
+ n = offs_n2[None, :]
1324
+ {{ modification(
1325
+ subgraph_number=0,
1326
+ output_name="post_mod_scores",
1327
+ score="qk",
1328
+ b="off_z",
1329
+ h="off_hq",
1330
+ m="m",
1331
+ n="n",
1332
+ out="qk"
1333
+ ) | indent_except_first(1) }}
1334
+
1335
+ if CHECK_BLOCK_BOUNDARY:
1336
+ # Mask out the elements that lie beyond KV_LEN for non-divisible seqlens.
1337
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
1338
+
1339
+ if not IS_FULL_BLOCKS:
1340
+ {{ modification(
1341
+ subgraph_number=2,
1342
+ output_name="mask_mod_output",
1343
+ score="qk",
1344
+ b="off_z",
1345
+ h="off_hq",
1346
+ m="m",
1347
+ n="n",
1348
+ ) | indent_except_first(2) }}
1349
+
1350
+ if CHECK_BLOCK_BOUNDARY:
1351
+ mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
1352
+ # apply mask for partial masked block
1353
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
1354
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1355
+ if not PRESCALE_QK:
1356
+ post_mod_scores *= RCP_LN2
1357
+ p = tl.math.exp2(post_mod_scores - lse)
1358
+ # Compute dP and dS.
1359
+ if IS_DIVISIBLE:
1360
+ vT = tl.load(vT_ptrs)
1361
+ else:
1362
+ vT = tl.load(vT_ptrs, mask=offs_n2[None, :] < KV_LEN)
1363
+ dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
1364
+ ds = p * (dp - Di[:, None])
1365
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
1366
+ {{ modification(
1367
+ subgraph_number=1,
1368
+ output_name = "grad_scores",
1369
+ score="pre_mod_scores",
1370
+ b="off_z",
1371
+ h="off_hq",
1372
+ m="m",
1373
+ n="n",
1374
+ grad_score_mod="ds"
1375
+ ) | indent_except_first(1) }}
1376
+ if CHECK_BLOCK_BOUNDARY:
1377
+ grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
1378
+
1379
+ ds = grad_scores
1380
+
1381
+ if not IS_FULL_BLOCKS:
1382
+ if CHECK_BLOCK_BOUNDARY:
1383
+ mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
1384
+ # (grads) apply mask for partially unmasked block
1385
+ ds = tl.where(mask_mod_output, ds, 0.0)
1386
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1387
+ ds = ds.to(MATMUL_PRECISION)
1388
+ # Compute dQ.
1389
+ dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
1390
+
1391
+ return dq
1392
+
1393
+
1394
+ @triton.jit
1395
+ def bwd_dkdv_inner(
1396
+ {{gen_argdefs()}},
1397
+ Q, DO, DELTA, LSE, # pointers
1398
+ dk, dv, k, v,
1399
+ off_z, off_hq, offs_n1, offs_m1,
1400
+ stride_qm, stride_qd, stride_dom, stride_dod,
1401
+ q_indices, sparse_q_num_blocks,
1402
+ MATMUL_PRECISION,
1403
+ IS_FULL_BLOCKS,
1404
+ ):
1405
+ {{gen_defines() | indent_except_first(1) }}
1406
+ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
1407
+ RCP_LN2: tl.constexpr = 1.44269504
1408
+ Q_LEN = {{size("Q", 2)}}
1409
+ KV_LEN = {{size("K", 2)}}
1410
+
1411
+ offs_k = tl.arange(0, QK_HEAD_DIM)
1412
+ offs_v = tl.arange(0, V_HEAD_DIM)
1413
+
1414
+ qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
1415
+ do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
1416
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
1417
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
1418
+ hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
1419
+
1420
+ if not IS_DIVISIBLE:
1421
+ if hi >= 1:
1422
+ for start_m in range(0, hi - 1):
1423
+ dk, dv = bwd_dkdv_block_mn(
1424
+ {{gen_argdefs()}},
1425
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
1426
+ off_z, off_hq, offs_n1, offs_m1,
1427
+ stride_qm, stride_qd, stride_dom, stride_dod,
1428
+ q_indices, sparse_q_num_blocks,
1429
+ MATMUL_PRECISION, RCP_LN2,
1430
+ IS_FULL_BLOCKS,
1431
+ )
1432
+ # Increment pointers.
1433
+ offset = get_offset_for_next_block(
1434
+ start_m, q_indices, sparse_q_num_blocks,
1435
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1
1436
+ )
1437
+
1438
+ qT_ptrs += offset * stride_qm
1439
+ do_ptrs += offset * stride_dom
1440
+
1441
+ offs_m1 += offset
1442
+
1443
+ dk, dv = bwd_dkdv_block_mn(
1444
+ {{gen_argdefs()}},
1445
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
1446
+ off_z, off_hq, offs_n1, offs_m1,
1447
+ stride_qm, stride_qd, stride_dom, stride_dod,
1448
+ q_indices, sparse_q_num_blocks,
1449
+ MATMUL_PRECISION, RCP_LN2,
1450
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
1451
+ )
1452
+ else:
1453
+ for start_m in range(0, hi):
1454
+ dk, dv = bwd_dkdv_block_mn(
1455
+ {{gen_argdefs()}},
1456
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
1457
+ off_z, off_hq, offs_n1, offs_m1,
1458
+ stride_qm, stride_qd, stride_dom, stride_dod,
1459
+ q_indices, sparse_q_num_blocks,
1460
+ MATMUL_PRECISION, RCP_LN2,
1461
+ IS_FULL_BLOCKS,
1462
+ )
1463
+ # Increment pointers.
1464
+ offset = get_offset_for_next_block(
1465
+ start_m, q_indices, sparse_q_num_blocks,
1466
+ SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1
1467
+ )
1468
+
1469
+ qT_ptrs += offset * stride_qm
1470
+ do_ptrs += offset * stride_dom
1471
+
1472
+ offs_m1 += offset
1473
+
1474
+ return dk, dv
1475
+
1476
+
1477
+ @triton.jit
1478
+ def bwd_dkdv_block_mn(
1479
+ {{gen_argdefs()}},
1480
+ dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
1481
+ off_z, off_hq, offs_n1, offs_m1,
1482
+ stride_qm, stride_qd, stride_dom, stride_dod,
1483
+ q_indices, sparse_q_num_blocks,
1484
+ MATMUL_PRECISION, RCP_LN2,
1485
+ IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
1486
+ ):
1487
+ {{gen_defines() | indent_except_first(1) }}
1488
+
1489
+ # Load LSE before computing qk to reduce pipeline stall.
1490
+ if IS_DIVISIBLE:
1491
+ qT = tl.load(qT_ptrs)
1492
+ lse = tl.load(LSE + offs_m1)
1493
+ else:
1494
+ qT = tl.load(qT_ptrs, mask=offs_m1[None, :] < Q_LEN)
1495
+ lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
1496
+ lse = tl.where(lse == -float("inf"), 0.0, lse)
1497
+ qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
1498
+ if not PRESCALE_QK:
1499
+ qkT *= SM_SCALE
1500
+ # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
1501
+ if CHECK_BLOCK_BOUNDARY:
1502
+ m = offs_m1[None, :] % Q_LEN
1503
+ n = offs_n1[:, None] % KV_LEN
1504
+ else:
1505
+ m = offs_m1[None, :]
1506
+ n = offs_n1[:, None]
1507
+ pre_mod_scores = qkT
1508
+ {{ modification(
1509
+ subgraph_number=0,
1510
+ output_name="post_mod_scores",
1511
+ score="qkT",
1512
+ b="off_z",
1513
+ h="off_hq",
1514
+ m="m",
1515
+ n="n",
1516
+ out="qkT"
1517
+ ) | indent_except_first(1) }}
1518
+
1519
+ if CHECK_BLOCK_BOUNDARY:
1520
+ # Mask out the elements that are beyond KV_LEN for non-divisible seqlen.
1521
+ post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf"))
1522
+
1523
+ if not IS_FULL_BLOCKS:
1524
+ {{ modification(
1525
+ subgraph_number=2,
1526
+ output_name="mask_mod_output",
1527
+ score="qkT",
1528
+ b="off_z",
1529
+ h="off_hq",
1530
+ m="m",
1531
+ n="n",
1532
+ ) | indent_except_first(2) }}
1533
+ if CHECK_BLOCK_BOUNDARY:
1534
+ mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
1535
+ # (grads) apply mask for fully masked block
1536
+ post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
1537
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1538
+ if not PRESCALE_QK:
1539
+ post_mod_scores *= RCP_LN2
1540
+ pT = tl.math.exp2(post_mod_scores - lse[None, :])
1541
+ if IS_DIVISIBLE:
1542
+ do = tl.load(do_ptrs)
1543
+ else:
1544
+ do = tl.load(do_ptrs, mask=offs_m1[:, None] < Q_LEN)
1545
+ # Compute dV.
1546
+ ppT = pT
1547
+ dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
1548
+ if IS_DIVISIBLE:
1549
+ Di = tl.load(DELTA + offs_m1)
1550
+ else:
1551
+ Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
1552
+ # Compute dP and dS.
1553
+ dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
1554
+ dsT = pT * (dpT - Di[None, :])
1555
+ # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
1556
+ {{ modification(
1557
+ subgraph_number=1,
1558
+ output_name = "grad_scores",
1559
+ score="pre_mod_scores",
1560
+ b="off_z",
1561
+ h="off_hq",
1562
+ m="m",
1563
+ n="n",
1564
+ grad_score_mod="dsT"
1565
+ ) | indent_except_first(1) }}
1566
+ if CHECK_BLOCK_BOUNDARY:
1567
+ grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0)
1568
+
1569
+ dsT = grad_scores
1570
+ if not IS_FULL_BLOCKS:
1571
+ if CHECK_BLOCK_BOUNDARY:
1572
+ mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
1573
+ # (grads) apply mask for partially unmasked block
1574
+ dsT = tl.where(mask_mod_output, dsT, 0.0)
1575
+ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1576
+ dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
1577
+
1578
+ return dk, dv
1579
+ """
1580
+ + compute_next_offset_func,
1581
+ )
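Note for readers of the template source above: bwd_dq_inner and bwd_dkdv_inner share the same control-flow trick. When the sequence length is not divisible by the block size, the final iteration is peeled off and run with CHECK_BLOCK_BOUNDARY=True, so only that one block pays for masked loads and boundary checks. A minimal sketch of the pattern in plain Python (illustration only; process_block is a hypothetical stand-in for bwd_dq_block_mn / bwd_dkdv_block_mn):

def peeled_block_loop(hi, is_divisible, process_block):
    # process_block(i, check_block_boundary) stands in for the template's
    # per-block helpers above.
    if not is_divisible:
        if hi >= 1:
            for i in range(hi - 1):
                # Interior blocks: no boundary masking needed.
                process_block(i, check_block_boundary=False)
            # Peeled last block: the only iteration that masks loads and
            # scores against the true sequence length.
            process_block(hi - 1, check_block_boundary=True)
    else:
        # Divisible case: every block is full, no peeling required.
        for i in range(hi):
            process_block(i, check_block_boundary=False)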
1582
+
1583
+
1584
+ # TODO: We probably also need a layout constraint?
1585
+ @register_lowering(
1586
+ torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None
1587
+ )
1588
+ def flex_attention_backward(*args, **kwargs):
1589
+ (
1590
+ query,
1591
+ key,
1592
+ value,
1593
+ out,
1594
+ logsumexp,
1595
+ grad_out,
1596
+ grad_logsumexp,
1597
+ fw_graph,
1598
+ joint_graph,
1599
+ block_mask,
1600
+ scale,
1601
+ kernel_options,
1602
+ score_mod_other_buffers,
1603
+ mask_mod_other_buffers,
1604
+ ) = args
1605
+ (
1606
+ kv_num_blocks,
1607
+ kv_indices,
1608
+ full_kv_num_blocks,
1609
+ full_kv_indices,
1610
+ q_num_blocks,
1611
+ q_indices,
1612
+ full_q_num_blocks,
1613
+ full_q_indices,
1614
+ SPARSE_KV_BLOCK_SIZE,
1615
+ SPARSE_Q_BLOCK_SIZE,
1616
+ mask_graph,
1617
+ ) = block_mask
1618
+
1619
+ (
1620
+ query,
1621
+ key,
1622
+ value,
1623
+ grad_out,
1624
+ kv_num_blocks,
1625
+ kv_indices,
1626
+ full_kv_num_blocks,
1627
+ full_kv_indices,
1628
+ q_num_blocks,
1629
+ q_indices,
1630
+ full_q_num_blocks,
1631
+ full_q_indices,
1632
+ ) = maybe_realize(
1633
+ [
1634
+ query,
1635
+ key,
1636
+ value,
1637
+ grad_out,
1638
+ kv_num_blocks,
1639
+ kv_indices,
1640
+ full_kv_num_blocks,
1641
+ full_kv_indices,
1642
+ q_num_blocks,
1643
+ q_indices,
1644
+ full_q_num_blocks,
1645
+ full_q_indices,
1646
+ ]
1647
+ )
1648
+
1649
+ device = query.get_device()
1650
+ dtype = query.get_dtype()
1651
+ Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
1652
+ Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
1653
+ assert Bq == Bkv, "Batch dimension must match"
1654
+ B = Bq
1655
+
1656
+ kernel_options = dict(kernel_options)
1657
+ kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision())
1658
+ if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0:
1659
+ kernel_options.setdefault("IS_DIVISIBLE", False)
1660
+ else:
1661
+ kernel_options.setdefault("IS_DIVISIBLE", True)
1662
+
1663
+ fwd_placeholder_inps = [
1664
+ create_placeholder(name, dtype, device)
1665
+ for name, dtype in [
1666
+ ("score", dtype),
1667
+ ("b", torch.int32),
1668
+ ("h", torch.int32),
1669
+ ("m", torch.int32),
1670
+ ("n", torch.int32),
1671
+ ]
1672
+ ]
1673
+ fw_subgraph_buffer = build_subgraph_buffer(
1674
+ fwd_placeholder_inps + list(score_mod_other_buffers), fw_graph
1675
+ )
1676
+
1677
+ joint_placeholder_inps = fwd_placeholder_inps + [
1678
+ create_placeholder("grad_score_mod", dtype, device)
1679
+ ]
1680
+ joint_subgraph_buffer, *_ = build_subgraph_buffer(
1681
+ joint_placeholder_inps + list(score_mod_other_buffers), joint_graph
1682
+ )
1683
+
1684
+ mask_graph_placeholder_inps = [
1685
+ create_placeholder(name, dtype, query.get_device())
1686
+ for name, dtype in [
1687
+ ("b", torch.int32),
1688
+ ("h", torch.int32),
1689
+ ("m", torch.int32),
1690
+ ("n", torch.int32),
1691
+ ]
1692
+ ]
1693
+ mask_graph_buffer = build_subgraph_buffer(
1694
+ mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph
1695
+ )
1696
+
1697
+ layout_k = FixedLayout(
1698
+ key.get_device(),
1699
+ key.get_dtype(),
1700
+ key.get_size(),
1701
+ key.get_stride(),
1702
+ )
1703
+
1704
+ # Create delta, which is needed by the bwd kernel
1705
+ grad_lse_exp2 = lowerings[aten.mul](grad_logsumexp, 1 / math.log(2))
1706
+ mul_delta = lowerings[aten.mul](out, grad_out)
1707
+ delta = lowerings[aten.sum](mul_delta, axis=-1)
1708
+ delta = lowerings[aten.sub](delta, grad_lse_exp2)
1709
+ delta = ExternKernel.require_contiguous(delta)
1710
+
1711
+ grad_lse_exp2, delta = maybe_realize([grad_lse_exp2, delta])
1712
+
1713
+ # see NOTE:[TritonTemplates with multiple outputs]
1714
+ grad_query = empty_strided(
1715
+ query.get_size(), query.get_stride(), dtype=dtype, device=device
1716
+ )
1717
+ grad_value = empty_strided(
1718
+ value.get_size(), value.get_stride(), dtype=dtype, device=device
1719
+ )
1720
+
1721
+ kernel_options.setdefault("SM_SCALE", scale)
1722
+
1723
+ # Determine GQA factor
1724
+ gqa_shared_heads = Hq // Hkv
1725
+ kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)
1726
+
1727
+ # Inside of Triton kernel, only apply partial masking if partial blocks are computed.
1728
+ # full_kv_num_blocks is torch.zeros([1, 1, 1]) if partial blocks are not computed.
1729
+ has_full_blocks = full_kv_num_blocks is not None
1730
+ kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
1731
+ if not has_full_blocks:
1732
+ full_kv_num_blocks, full_kv_indices, full_q_num_blocks, full_q_indices = (
1733
+ empty(0, device=query.get_device()) for _ in range(4)
1734
+ )
1735
+ kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim)
1736
+ kernel_options.setdefault("V_HEAD_DIM", v_head_dim)
1737
+
1738
+ choices: List[Any] = []
1739
+ configs: List[Tuple[int, int, int, int]] = []
1740
+ configs.append(_get_default_config_bwd(query))
1741
+ if config.max_autotune:
1742
+ configs.extend(
1743
+ [
1744
+ (BLOCK1, BLOCK2, w, s)
1745
+ for BLOCK1 in [32, 64]
1746
+ for BLOCK2 in [32, 64, 128]
1747
+ for w in [4, 8]
1748
+ for s in [1, 3, 4, 5]
1749
+ if BLOCK2 % BLOCK1 == 0
1750
+ ]
1751
+ )
1752
+
1753
+ for BLOCK1, BLOCK2, num_warps, num_stages in configs:
1754
+ if (
1755
+ SPARSE_KV_BLOCK_SIZE % BLOCK1 != 0
1756
+ or SPARSE_Q_BLOCK_SIZE % BLOCK1 != 0
1757
+ or SPARSE_KV_BLOCK_SIZE % BLOCK2 != 0
1758
+ or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
1759
+ ):
1760
+ continue
1761
+
1762
+ # Performance tuning
1763
+ kernel_options.setdefault("BLOCK_M1", BLOCK1)
1764
+ kernel_options.setdefault("BLOCK_N1", BLOCK2)
1765
+ kernel_options.setdefault("BLOCK_M2", BLOCK2)
1766
+ kernel_options.setdefault("BLOCK_N2", BLOCK1)
1767
+ # Blocksparse options
1768
+ kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
1769
+ kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
1770
+
1771
+ flex_attention_backward_template.maybe_append_choice(
1772
+ choices=choices,
1773
+ input_nodes=[
1774
+ query,
1775
+ key,
1776
+ value,
1777
+ logsumexp,
1778
+ delta,
1779
+ grad_out,
1780
+ grad_query,
1781
+ grad_value,
1782
+ kv_num_blocks,
1783
+ kv_indices,
1784
+ q_num_blocks,
1785
+ q_indices,
1786
+ full_kv_num_blocks,
1787
+ full_kv_indices,
1788
+ full_q_num_blocks,
1789
+ full_q_indices,
1790
+ ],
1791
+ layout=layout_k, # We use store_output only for grad_key
1792
+ subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer, mask_graph_buffer],
1793
+ mutated_inputs=[grad_query, grad_value],
1794
+ call_sizes=query.get_size() + key.get_size()[1:3],
1795
+ num_stages=num_stages,
1796
+ num_warps=num_warps,
1797
+ **kernel_options,
1798
+ )
1799
+ inputs_for_autotuning = (
1800
+ [
1801
+ query,
1802
+ key,
1803
+ value,
1804
+ logsumexp,
1805
+ delta,
1806
+ grad_out,
1807
+ grad_query,
1808
+ grad_value,
1809
+ kv_num_blocks,
1810
+ kv_indices,
1811
+ q_num_blocks,
1812
+ q_indices,
1813
+ full_kv_num_blocks,
1814
+ full_kv_indices,
1815
+ full_q_num_blocks,
1816
+ full_q_indices,
1817
+ ]
1818
+ + list(score_mod_other_buffers)
1819
+ + list(mask_mod_other_buffers)
1820
+ )
1821
+ input_gen_fns = {
1822
+ 8: create_num_blocks_fake_generator(kv_indices), # kv_num_blocks
1823
+ 9: create_indices_fake,
1824
+ 10: create_num_blocks_fake_generator(q_indices), # q_num_blocks
1825
+ 11: create_indices_fake,
1826
+ 12: create_num_blocks_fake_generator(full_kv_indices), # full_kv_num_blocks
1827
+ 13: create_indices_fake,
1828
+ 14: create_num_blocks_fake_generator(full_q_indices), # full_q_num_blocks
1829
+ 15: create_indices_fake,
1830
+ }
1831
+
1832
+ grad_key = autotune_select_algorithm(
1833
+ "flex_attention_backward",
1834
+ choices,
1835
+ inputs_for_autotuning,
1836
+ layout_k,
1837
+ input_gen_fns=input_gen_fns,
1838
+ )
1839
+ return (
1840
+ grad_query,
1841
+ grad_key,
1842
+ grad_value,
1843
+ )
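For readers tracing the backward lowering above: before any Triton code runs, a per-row delta buffer is built out of plain aten lowerings. A hedged eager-mode sketch of that same precompute (illustration only, not part of the shipped file; shapes are assumed to be [B, Hq, M, D] for out / grad_out and [B, Hq, M] for grad_logsumexp):

import math
import torch

def reference_delta(out, grad_out, grad_logsumexp):
    # Mirrors the lowered graph: delta = sum_d(out * grad_out) - grad_lse / ln(2),
    # where grad_lse_exp2 = grad_logsumexp * (1 / math.log(2)) as in the code above.
    grad_lse_exp2 = grad_logsumexp * (1.0 / math.log(2))
    return (out * grad_out).sum(dim=-1) - grad_lse_exp2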
.venv/Lib/site-packages/torch/_inductor/kernel/flex_decoding.py ADDED
@@ -0,0 +1,570 @@
1
+ # mypy: allow-untyped-defs
2
+ """ Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""
3
+ from typing import Any, List, Tuple
4
+
5
+ import sympy
6
+
7
+ import torch
8
+ from torch._inductor.virtualized import V
9
+
10
+ from .. import config, ir
11
+ from ..ir import FixedLayout, FlexibleLayout
12
+ from ..lowering import empty, empty_strided, lowerings
13
+ from ..runtime.runtime_utils import is_power_of_2, next_power_of_2
14
+ from ..select_algorithm import autotune_select_algorithm, TritonTemplate
15
+ from .flex_attention import (
16
+ compute_forward_block_mn,
17
+ compute_forward_inner,
18
+ compute_next_offset_func,
19
+ create_indices_fake,
20
+ create_num_blocks_fake_generator,
21
+ maybe_realize,
22
+ )
23
+
24
+
25
+ aten = torch.ops.aten
26
+ prims = torch.ops.prims
27
+
28
+
29
+ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta):
30
+ """How is this kernel parallelized?
31
+ We create a grid of (batch_size * kv_heads, SPLIT_KV, 1)
32
+ Each block is responsible for iterating over its assigned blocks of keys and values, calculating
34
+ a local output for its tile of keys and values over the full query length.
35
+ Groups of SPLIT_KV blocks then combine their outputs to produce the final result.
35
+ """
36
+
37
+ return (batch_size * kv_heads, meta["SPLIT_KV"], 1)
38
+
39
+
40
+ flex_decoding_template = TritonTemplate(
41
+ name="flex_decoding",
42
+ grid=flex_decoding_grid,
43
+ source=r"""
44
+ {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
45
+ # Sub notation for this kernel:
46
+ # Q: Query, K: Key, V: Value
47
+ # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
48
+ # M: Number of queries, N: Number of keys/values
49
+ # QK_HEAD_DIM: The dimension of the query and key embeddings
50
+ # V_HEAD_DIM: The dimension of the value embeddings
51
+ # BLOCK_M, QK_HEAD_DIM: the M and D dimensions are always assigned to the same block
52
+ # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head, t: Number of kv splits
53
+ # (Modifiable) Config options:
54
+ # SPLIT_KV: number of blocks K & V are split into
55
+ # TILE_KV: length of each local KV split
56
+ # BLOCK_M: block size that Q is padded to along the seqlen dim.
57
+ # BLOCK_N: block size of K & V along N dimension.
58
+ # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
59
+ #
60
+ # change of base out of the loop
61
+ # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
62
+ # is not masked out? If so, we can skip an extra safety check
63
+ # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
64
+ # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
65
+
66
+ # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
67
+ #
68
+ # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
69
+ # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
70
+ # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
71
+ #
72
+ #
73
+ # Output: ACC output accumulated across local KV split.
74
+
75
+ tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
76
+
77
+ # Define Q Strides
78
+ stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}}
79
+ stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
80
+ stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
81
+ stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}}
82
+ stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}}
83
+
84
+
85
+ Z = {{size("Q", 0)}}
86
+ HKV = {{size("Q", 1)}}
87
+ G: tl.constexpr = GQA_SHARED_HEADS
88
+ HQ = HKV * G
89
+ Q_LEN = {{size("Q", 3)}}
90
+ KV_LEN = {{size("K", 2)}}
91
+
92
+ MATMUL_PRECISION = Q.dtype.element_ty
93
+
94
+ # Make sure each split is a multiple of BLOCK_N
95
+ TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
96
+ TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
97
+ TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
98
+
99
+ off_z = tl.program_id(0) // HKV
100
+ off_hkv = tl.program_id(0) % HKV
101
+ off_t = tl.program_id(1)
102
+
103
+ q_offset = off_z * stride_qz + off_hkv * stride_qh
104
+ k_offset = off_z * stride_kz + off_hkv * stride_kh
105
+ v_offset = off_z * stride_vz + off_hkv * stride_vh
106
+
107
+ SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
108
+ SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
109
+
110
+ sparse_idx_z = off_z % SPARSE_Z
111
+ # TODO: support masks not broadcasted along the head dimension.
112
+ tl.device_assert(SPARSE_HQ == 1)
113
+ sparse_idx_h = 0
114
+
115
+ SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
116
+ SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
117
+
118
+ # initialize pointer to m and l
119
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
120
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
121
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)
122
+
123
+ # initialize offsets
124
+ tl.device_assert(BLOCK_M % G == 0)
125
+ BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
126
+ off_g = tl.arange(0, G) # [G]
127
+ offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
128
+ offs_hq = offs_g + off_hkv * G
129
+ off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
130
+ offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
131
+ offs_d = tl.arange(0, QK_HEAD_DIM)
132
+ offs_vd = tl.arange(0, V_HEAD_DIM)
133
+
134
+ # KV_IDX / FULL_KV_IDX and KV_NUM_BLKS / FULL_KV_NUM_BLKS are always contiguous.
135
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_h
136
+
137
+ # Calculate KV blocks that belong this CTA.
138
+ block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
139
+ block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
140
+
141
+ q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
142
+
143
+ if SAFE_M_BOUNDARY:
144
+ q = tl.load(Q + q_offset + q_range)
145
+ else:
146
+ mask = off_m[None, :, None] < Q_LEN
147
+ q = tl.load(Q + q_offset + q_range, mask)
148
+
149
+ q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM])
150
+
151
+
152
+ # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
153
+ # Apply both score_mod and mask_mod
154
+
155
+ # find first kv block we are loading and the number of blocks we are loading
156
+ kv_indices = KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
157
+ kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_hz_offset)
158
+ indices_idx = block_n_start // SPARSE_KV_MULTIPLE
159
+ off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
160
+ off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
161
+ # first kv block we're loading
162
+
163
+ # last valid block according to sparse mask
164
+ block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
165
+
166
+ K_block_ptr = tl.make_block_ptr(
167
+ base=K + k_offset,
168
+ shape=(QK_HEAD_DIM, KV_LEN), # (d, N)
169
+ strides=(stride_kk, stride_kn),
170
+ offsets=(0, off_n),
171
+ block_shape=(QK_HEAD_DIM, BLOCK_N),
172
+ order=(0, 1)
173
+ )
174
+ V_block_ptr = tl.make_block_ptr(
175
+ base=V + v_offset,
176
+ shape=(KV_LEN, V_HEAD_DIM),
177
+ strides=(stride_vn, stride_vk),
178
+ offsets=(off_n, 0),
179
+ block_shape=(BLOCK_N, V_HEAD_DIM),
180
+ order=(1, 0)
181
+ )
182
+ offs_n = tl.arange(0, BLOCK_N) + off_n
183
+
184
+ acc, l_i, m_i = forward_inner(
185
+ {{gen_argdefs()}},
186
+ q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
187
+ # accumulated values
188
+ acc, l_i, m_i,
189
+ #offsets
190
+ off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
191
+ #block sparse data
192
+ kv_indices, kv_num_blocks,
193
+ block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
194
+ MATMUL_PRECISION,
195
+ IS_FULL_BLOCKS=False,
196
+ )
197
+
198
+
199
+ # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
200
+ # We know these blocks are guaranteed to be "full", so we don't need to
201
+ # apply mask_mod to them - only score_mod
202
+ if HAS_FULL_BLOCKS:
203
+ kv_indices = FULL_KV_IDX + sparse_hz_offset * SPARSE_KV_BLOCK_CNT
204
+ kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_hz_offset)
205
+ indices_idx = block_n_start // SPARSE_KV_MULTIPLE
206
+ off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
207
+ off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
208
+
209
+ # last valid block according to sparse mask
210
+ block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
211
+
212
+ K_block_ptr = tl.make_block_ptr(
213
+ base=K + k_offset,
214
+ shape=(QK_HEAD_DIM, KV_LEN), # (d, N)
215
+ strides=(stride_kk, stride_kn),
216
+ offsets=(0, off_n),
217
+ block_shape=(QK_HEAD_DIM, BLOCK_N),
218
+ order=(0, 1)
219
+ )
220
+ V_block_ptr = tl.make_block_ptr(
221
+ base=V + v_offset,
222
+ shape=(KV_LEN, V_HEAD_DIM),
223
+ strides=(stride_vn, stride_vk),
224
+ offsets=(off_n, 0),
225
+ block_shape=(BLOCK_N, V_HEAD_DIM),
226
+ order=(1, 0)
227
+ )
228
+ offs_n = tl.arange(0, BLOCK_N) + off_n
229
+
230
+ acc, l_i, m_i = forward_inner(
231
+ {{gen_argdefs()}},
232
+ q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
233
+ # accumulated values
234
+ acc, l_i, m_i,
235
+ #offsets
236
+ off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
237
+ #block sparse data
238
+ kv_indices, kv_num_blocks,
239
+ block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
240
+ MATMUL_PRECISION,
241
+ IS_FULL_BLOCKS=True,
242
+ )
243
+
244
+ m_offset = off_t * stride_mt + off_z * stride_mz
245
+ l_offset = off_t * stride_lt + off_z * stride_lz
246
+
247
+ M_block_ptr = tl.make_block_ptr(
248
+ base=M + m_offset,
249
+ shape=(G, Q_LEN), # (G, M)
250
+ strides=(stride_mh, stride_mm),
251
+ offsets=(off_hkv*G, 0),
252
+ block_shape=(G, BLOCK_M_PER_HQ),
253
+ order=(1, 0)
254
+ )
255
+ L_block_ptr = tl.make_block_ptr(
256
+ base=L + l_offset,
257
+ shape=(G, Q_LEN), # (G, M)
258
+ strides=(stride_lh, stride_lm),
259
+ offsets=(off_hkv*G, 0),
260
+ block_shape=(G, BLOCK_M_PER_HQ),
261
+ order=(1, 0)
262
+ )
263
+
264
+ # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
265
+ m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
266
+ l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
267
+ if SAFE_M_BOUNDARY:
268
+ tl.store(M_block_ptr, m_i)
269
+ tl.store(L_block_ptr, l_i)
270
+ else:
271
+ tl.store(M_block_ptr, m_i, boundary_check=(1,))
272
+ tl.store(L_block_ptr, l_i, boundary_check=(1,))
273
+
274
+ # -- store output
275
+ idx_z = off_z
276
+ idx_t = off_t
277
+ idx_hq = off_hkv*G + off_g[:, None, None]
278
+ idx_m = off_m[None, :, None]
279
+ idx_d = offs_vd[None, None, :]
280
+
281
+ mask = (idx_m < Q_LEN)
282
+ acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
283
+ {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
284
+ """
285
+ + compute_forward_inner
286
+ + compute_next_offset_func
287
+ + compute_forward_block_mn,
288
+ )
289
+
290
+
291
+ def get_split_k(B: int, H: int, Mk: int, SM: int = 128) -> int:
292
+ """Heuristic for the number of splits from xformer"""
293
+ bh = max(B * H, 1) # NOTE: Handle B*h=0 case
294
+ split_k = SM // bh # Each SM should at least get one block.
295
+ split_k = max(split_k, 1)
296
+
297
+ return split_k
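A short worked example of the heuristic above (values are illustrative and assume the default SM=128): with B=2 batches and H=4 kv heads, bh = 8 and each (batch, head) pair is given 128 // 8 = 16 KV splits, so flex_decoding_grid launches a (2 * 4, 16, 1) grid.

split_k = get_split_k(B=2, H=4, Mk=4096)  # -> 16; Mk is not used by this heuristic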
298
+
299
+
300
+ def _get_decoding_default_config(key) -> Tuple[int, int, int]:
301
+ dtype = key.get_dtype()
302
+ head_dim = key.get_size()[-1]
303
+ sm_version = torch.cuda.get_device_capability()
304
+ default_config = (64, 2, 1)
305
+ if sm_version >= (9, 0):
306
+ if head_dim > 128 and dtype == torch.float32:
307
+ return default_config
308
+ return (64, 2, 3)
309
+ return default_config
310
+
311
+
312
+ def create_flex_decoding_kernel(*args, **kwargs):
313
+ (
314
+ query,
315
+ key,
316
+ value,
317
+ block_mask,
318
+ scale,
319
+ kernel_options,
320
+ score_mod_subgraph,
321
+ mask_mod_subgraph,
322
+ score_mod_other_buffers,
323
+ mask_mod_other_buffers,
324
+ ) = args
325
+ (
326
+ kv_num_blocks,
327
+ kv_indices,
328
+ full_kv_num_blocks, # full_kv_num_blocks,
329
+ full_kv_indices, # full_kv_indices,
330
+ _, # q_num_blocks
331
+ _, # q_indices
332
+ _, # full_q_num_blocks,
333
+ _, # full_q_indices,
334
+ SPARSE_KV_BLOCK_SIZE,
335
+ _, # SPARSE_Q_BLOCK_SIZE,
336
+ _,
337
+ ) = block_mask
338
+
339
+ Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
340
+ Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
341
+ assert Bq == Bkv, "Batch dimension must match"
342
+ B = Bq
343
+ kernel_options = dict(kernel_options)
344
+
345
+ # TODO: Fix flex decoding non-divisible case!
346
+ if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0:
347
+ kernel_options.setdefault("IS_DIVISIBLE", False)
348
+ else:
349
+ kernel_options.setdefault("IS_DIVISIBLE", True)
350
+
351
+ # Calculate GQA head sharing
352
+ gqa_shared_heads = Hq // Hkv
353
+ if not is_power_of_2(gqa_shared_heads):
354
+ raise ValueError(
355
+ "Number of shared query heads sharing the same KV head must be a power of 2. "
356
+ )
357
+ kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)
358
+
359
+ # Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod
360
+ has_full_blocks = full_kv_num_blocks is not None
361
+ kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
362
+ if not has_full_blocks:
363
+ # Create a placeholder full block list in case it is empty
364
+ full_kv_num_blocks, full_kv_indices = (
365
+ empty(0, device=query.get_device()) for _ in range(2)
366
+ )
367
+
368
+ (
369
+ query,
370
+ key,
371
+ value,
372
+ kv_num_blocks,
373
+ kv_indices,
374
+ full_kv_num_blocks,
375
+ full_kv_indices,
376
+ ) = maybe_realize(
377
+ [
378
+ query,
379
+ key,
380
+ value,
381
+ kv_num_blocks,
382
+ kv_indices,
383
+ full_kv_num_blocks,
384
+ full_kv_indices,
385
+ ]
386
+ )
387
+
388
+ choices: List[Any] = []
389
+ configs: List[Tuple[int, int, int]] = []
390
+ configs.append(_get_decoding_default_config(key))
391
+ # Note: max_autotune is not supported yet. Causes error in lowering the dynamic shape in reduction ops.
392
+ if config.max_autotune:
393
+ configs += [
394
+ (64, 2, 2),
395
+ (32, 2, 3),
396
+ (128, 2, 3),
397
+ ]
398
+ # TODO: fix autotuning.
399
+
400
+ kernel_options.setdefault("SM_SCALE", scale)
401
+ kernel_options.setdefault("SPLIT_KV", get_split_k(B, Hkv, seq_len_kv))
402
+ MAX_SPLIT_KV = kernel_options["SPLIT_KV"]
403
+
404
+ # create config dependent intermediate buffers
405
+ buf_ACC_shape = [B, MAX_SPLIT_KV, Hq, seq_len_q, v_head_dim]
406
+ buf_ML_shape = buf_ACC_shape[:-1]
407
+ buf_M = empty_strided(
408
+ buf_ML_shape,
409
+ None,
410
+ dtype=torch.float32, # The rowmax is always stored in fp32 regardless of the input dtype
411
+ device=query.get_device(),
412
+ )
413
+ buf_L = empty_strided(
414
+ buf_ML_shape,
415
+ None,
416
+ dtype=torch.float32, # The intermediate sumexp is always stored in fp32 regardless of the input dtype
417
+ device=query.get_device(),
418
+ )
419
+
420
+ layout_acc = FixedLayout(
421
+ query.get_device(),
422
+ torch.float32,
423
+ buf_ACC_shape,
424
+ FlexibleLayout.contiguous_strides(buf_ACC_shape),
425
+ )
426
+
427
+ kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim)
428
+ kernel_options.setdefault("V_HEAD_DIM", v_head_dim)
429
+
430
+ kernel_options.setdefault(
431
+ "BLOCK_M",
432
+ (
433
+ # m
434
+ # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0))
435
+ # else # Always use a BLOCK_M > 16 before Triton fix https://github.com/triton-lang/triton/pull/4061 is in pin
436
+ max(
437
+ next_power_of_2(
438
+ V.graph.sizevars.size_hint(
439
+ seq_len_q, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
440
+ )
441
+ * gqa_shared_heads
442
+ ),
443
+ 16,
444
+ )
445
+ ),
446
+ )
447
+
448
+ query = ir.ExternKernel.realize_input(query)
449
+ stride_b, stride_hq, stride_seq_len_q, stride_qk_head_dim = query.get_stride()
450
+
451
+ # Reshape query for GQA: [B, Hq, Mq, D] -> [B, Hkv, G, Mq, D]
452
+ gqa_query_shape = (B, Hkv, gqa_shared_heads, seq_len_q, qk_head_dim)
453
+ gqa_query_stride = (
454
+ stride_b,
455
+ stride_hq * gqa_shared_heads,
456
+ stride_hq,
457
+ stride_seq_len_q,
458
+ stride_qk_head_dim,
459
+ )
460
+ query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride)
461
+
462
+ V.graph.sizevars.guard_leq(
463
+ seq_len_q * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"])
464
+ )
465
+
466
+ kernel_options.setdefault(
467
+ "SAFE_M_BOUNDARY",
468
+ ((seq_len_q * gqa_shared_heads) % kernel_options["BLOCK_M"]) == 0,
469
+ )
470
+ # TODO: This feels sketchy
471
+ kernel_options.setdefault("SAFE_N_BOUNDARY", True)
472
+
473
+ # Note, we don't need to pass in the captured buffers explicitly
474
+ # because they're implicitly added by the score_mod function
475
+ # We do need to explicitly pass it in for autotuning though.
476
+ for BLOCK_N, num_warps, num_stages in configs:
477
+ if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0:
478
+ continue
479
+
480
+ # Performance tuning
481
+ kernel_options.setdefault("BLOCK_N", BLOCK_N)
482
+ kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
483
+
484
+ # Work around https://github.com/pytorch/pytorch/issues/129625
485
+ if num_stages == 2:
486
+ continue
487
+ flex_decoding_template.maybe_append_choice(
488
+ choices=choices,
489
+ input_nodes=[
490
+ query,
491
+ key,
492
+ value,
493
+ buf_M,
494
+ buf_L,
495
+ kv_num_blocks,
496
+ kv_indices,
497
+ full_kv_num_blocks,
498
+ full_kv_indices,
499
+ ],
500
+ layout=layout_acc,
501
+ subgraphs=[
502
+ score_mod_subgraph,
503
+ mask_mod_subgraph,
504
+ ],
505
+ mutated_inputs=[buf_M, buf_L],
506
+ num_stages=num_stages,
507
+ num_warps=num_warps,
508
+ call_sizes=query.get_size(),
509
+ **kernel_options,
510
+ )
511
+
512
+ inputs_for_flex_decoding = (
513
+ [
514
+ query,
515
+ key,
516
+ value,
517
+ buf_M,
518
+ buf_L,
519
+ kv_num_blocks,
520
+ kv_indices,
521
+ full_kv_num_blocks,
522
+ full_kv_indices,
523
+ ]
524
+ + list(score_mod_other_buffers)
525
+ + list(mask_mod_other_buffers)
526
+ )
527
+
528
+ input_gen_fns = {
529
+ 5: create_num_blocks_fake_generator(kv_indices),
530
+ 6: create_indices_fake,
531
+ 7: create_num_blocks_fake_generator(full_kv_indices),
532
+ 8: create_indices_fake,
533
+ }
534
+
535
+ buf_ACC = autotune_select_algorithm(
536
+ "flex_decoding",
537
+ choices,
538
+ inputs_for_flex_decoding,
539
+ layout_acc,
540
+ input_gen_fns=input_gen_fns,
541
+ )
542
+
543
+ # Reduction
544
+
545
+ g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0]
546
+ # See [Note] Handle fully masked out rows:
547
+ # g_M is the global max among split kv blocks.
548
+ masked_rows = lowerings[aten.eq](g_M, -float("inf"))
549
+ adj_M = lowerings[aten.sub](buf_M, g_M)
550
+ adj_M = lowerings[aten.where](masked_rows, 0, adj_M)
551
+ alpha = lowerings[aten.exp2](adj_M)
552
+
553
+ buf_L = lowerings[aten.mul](buf_L, alpha)
554
+ g_L = lowerings[aten.sum](buf_L, axis=1)
555
+ masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1)
556
+ g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L)
557
+ logsumexp = lowerings[aten.log2](g_L)
558
+ logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1))
559
+
560
+ alpha_unseq = lowerings[aten.unsqueeze](alpha, 4)
561
+ buf_ACC = lowerings[aten.mul](buf_ACC, alpha_unseq)
562
+ output = lowerings[aten.sum](buf_ACC, axis=1)
563
+ L_unseq = lowerings[aten.unsqueeze](g_L, 3)
564
+ output = lowerings[aten.div](output, L_unseq)
565
+ output = lowerings[prims.convert_element_type](output, query.get_dtype())
566
+
567
+ return (
568
+ output,
569
+ logsumexp,
570
+ )
.venv/Lib/site-packages/torch/_inductor/kernel/mm.py ADDED
@@ -0,0 +1,776 @@
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import logging
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import torch
7
+ from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm
8
+ from torch._inductor.autoheuristic.autoheuristic_utils import (
9
+ AHContext,
10
+ context_add_strides,
11
+ context_add_using_tf32,
12
+ get_mixedmm_precondition,
13
+ mixed_mm_operations,
14
+ mm_operations,
15
+ )
16
+ from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate
17
+ from torch._inductor.virtualized import V
18
+
19
+ from .. import config as inductor_config
20
+ from ..codegen.common import BackendFeature
21
+ from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
22
+ from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
23
+ from ..codegen.wrapper import WrapperCodeGen
24
+ from ..ir import FlexibleLayout, is_triton
25
+ from ..lowering import register_lowering
26
+ from ..select_algorithm import (
27
+ autotune_select_algorithm,
28
+ ExternKernelChoice,
29
+ NoValidChoicesError,
30
+ TritonTemplate,
31
+ )
32
+ from ..utils import (
33
+ get_gpu_shared_memory,
34
+ use_aten_gemm_kernels,
35
+ use_ck_template,
36
+ use_cpp_packed_gemm_template,
37
+ use_cutlass_template,
38
+ use_max_autotune,
39
+ use_triton_template,
40
+ )
41
+ from .mm_common import (
42
+ addmm_epilogue,
43
+ extra_mm_configs,
44
+ int8_mm_configs,
45
+ mixed_mm_configs,
46
+ mm_args,
47
+ mm_configs,
48
+ mm_grid,
49
+ mm_options,
50
+ triton_config,
51
+ )
52
+
53
+
54
+ log = logging.getLogger(__name__)
55
+ aten = torch.ops.aten
56
+
57
+ mm_template = TritonTemplate(
58
+ name="mm",
59
+ grid=mm_grid,
60
+ source=r"""
61
+ {{def_kernel("A", "B")}}
62
+ M = {{size("A", 0)}}
63
+ N = {{size("B", 1)}}
64
+ K = {{size("A", 1)}}
65
+ if M * N == 0:
66
+ # early exit due to zero-size input(s)
67
+ return
68
+ stride_am = {{stride("A", 0)}}
69
+ stride_ak = {{stride("A", 1)}}
70
+ stride_bk = {{stride("B", 0)}}
71
+ stride_bn = {{stride("B", 1)}}
72
+
73
+ # based on triton.ops.matmul
74
+ pid = tl.program_id(0)
75
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
76
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
77
+
78
+ # re-order program ID for better L2 performance
79
+ width = GROUP_M * grid_n
80
+ group_id = pid // width
81
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
82
+ pid_m = group_id * GROUP_M + (pid % group_size)
83
+ pid_n = (pid % width) // (group_size)
84
+
85
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
86
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
87
+ if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
88
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
89
+ else:
90
+ ram = rm % M
91
+ if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
92
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
93
+ else:
94
+ rbn = rn % N
95
+ rk = tl.arange(0, BLOCK_K)
96
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
97
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
98
+
99
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
100
+ for k in range(K, 0, -BLOCK_K):
101
+ if EVEN_K:
102
+ a = tl.load(A)
103
+ b = tl.load(B)
104
+ else:
105
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
106
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
107
+ if B_PROLOGUE_CAST_TYPE is not None:
108
+ b = b.to(B_PROLOGUE_CAST_TYPE)
109
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
110
+ A += BLOCK_K * stride_ak
111
+ B += BLOCK_K * stride_bk
112
+
113
+ # rematerialize rm and rn to save registers
114
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
115
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
116
+ idx_m = rm[:, None]
117
+ idx_n = rn[None, :]
118
+ mask = (idx_m < M) & (idx_n < N)
119
+
120
+ # inductor generates a suffix
121
+ {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
122
+ """,
123
+ )
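The "re-order program ID for better L2 performance" block inside the mm template above implements a grouped swizzle: consecutive program IDs walk down up to GROUP_M row-blocks of the same column before moving on, so neighbouring CTAs tend to reuse the same tiles from L2. A plain-Python sketch of the same index arithmetic (illustration only):

def swizzle_pid(pid, grid_m, grid_n, group_m):
    width = group_m * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * group_m, group_m)
    pid_m = group_id * group_m + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n

# e.g. grid_m=8, grid_n=8, group_m=4: pids 0..3 map to rows 0..3 of column 0,
# pids 4..7 to rows 0..3 of column 1, and so on, instead of sweeping a full row.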
124
+
125
+ aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")
126
+
127
+ aten_addmm = ExternKernelChoice(
128
+ torch.addmm, "at::addmm_out", op_overload=aten.addmm.default
129
+ )
130
+
131
+ aten__int_mm = ExternKernelChoice(torch._int_mm, "at::_int_mm")
132
+
133
+ aten__sparse_semi_structured_mm = ExternKernelChoice(
134
+ torch._sparse_semi_structured_mm,
135
+ "at::_sparse_semi_structured_mm",
136
+ has_out_variant=False,
137
+ )
138
+
139
+
140
+ def _is_int8_mat(mat):
141
+ return mat.get_dtype() in (torch.int8, torch.uint8)
142
+
143
+
144
+ def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
145
+ """
146
+ Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
147
+ kernel under the hood. There are a few shapes where this is slower,
148
+ but they are rare.
149
+ """
150
+ if inp.stride(0) == 0 or inp.size(0) == 1:
151
+ return torch.addmm(inp[0], mat1, mat2, out=out, alpha=alpha, beta=beta)
152
+ return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta)
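A hedged usage illustration of the dispatch above (not part of the file): a bias whose leading stride is 0, i.e. a broadcast 1D bias that has been expanded to 2D, is unexpanded back to a single row so the faster fused cublasLt epilogue inside torch.addmm can be used.

bias = torch.randn(16, device="cuda").expand(32, 16)   # stride(0) == 0
a = torch.randn(32, 64, device="cuda")
b = torch.randn(64, 16, device="cuda")
out = torch.empty(32, 16, device="cuda")
bias_addmm(bias, a, b, out=out)   # takes the inp[0] (1D bias) fast path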
153
+
154
+
155
+ aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
156
+
157
+
158
+ @register_lowering(aten.mm, type_promotion_kind=None)
159
+ def tuned_mm(mat1, mat2, *, layout=None):
160
+ m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
161
+ name = "mm"
162
+
163
+ aten_layout = layout
164
+ if not use_max_autotune():
165
+ aten_layout = FlexibleLayout(
166
+ device=layout.device, dtype=layout.dtype, size=layout.size
167
+ )
168
+
169
+ # options to tune from
170
+ choices = (
171
+ [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else []
172
+ )
173
+ static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
174
+ if is_nonzero and use_triton_template(layout):
175
+ for config in mm_configs(m, n, k):
176
+ mm_template.maybe_append_choice(
177
+ choices,
178
+ input_nodes=(mat1, mat2),
179
+ layout=layout,
180
+ **mm_options(config, m, n, k, layout),
181
+ )
182
+ if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
183
+ CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
184
+
185
+ if is_nonzero and use_ck_template(layout, m, n, k):
186
+ CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2])
187
+
188
+ if use_cpp_packed_gemm_template(layout, mat1, mat2):
189
+ CppPackedGemmTemplate.add_choices(
190
+ choices,
191
+ layout,
192
+ [mat1, mat2],
193
+ )
194
+
195
+ input_nodes = [mat1, mat2]
196
+ if (
197
+ is_nonzero
198
+ and use_triton_template(layout)
199
+ and torch._inductor.config.run_autoheuristic(name)
200
+ and is_triton(mat1)
201
+ ):
202
+ always_included = []
203
+ if use_aten_gemm_kernels():
204
+ always_included.append("extern_mm")
205
+ num_choices_before_extra_configs = len(choices)
206
+ for config in extra_mm_configs(m, n, k):
207
+ mm_template.maybe_append_choice(
208
+ choices,
209
+ input_nodes=(mat1, mat2),
210
+ layout=layout,
211
+ **mm_options(config, m, n, k, layout),
212
+ )
213
+
214
+ # using AutoHeuristic for ranking
215
+ ah_choices = mm_autoheuristic(
216
+ mat1,
217
+ mat2,
218
+ m,
219
+ n,
220
+ k,
221
+ choices,
222
+ name,
223
+ input_nodes,
224
+ mm_operations(),
225
+ None,
226
+ top_k=10,
227
+ always_included=always_included,
228
+ )
229
+ if not torch._inductor.config.collect_autoheuristic(name):
230
+ # if we are collecting data, we do not want to modify choices
231
+ if ah_choices is not None and len(ah_choices) > 0:
232
+ # the order in which autoheuristic returns choices is not the same as
233
+ # the order of choices, which affects things like epilogue fusion.
234
+ # once epilogue fusion benchmarks choices in sorted order, I think we can
235
+ # just use the order returned by autoheuristic
236
+ choices = [choice for choice in choices if choice in ah_choices]
237
+ else:
238
+ choices = choices[:num_choices_before_extra_configs]
239
+
240
+ if (
241
+ len(choices) == 0
242
+ and not use_aten_gemm_kernels()
243
+ and inductor_config.autotune_fallback_to_aten
244
+ ):
245
+ log.warning("No choices for GEMM, using ATen backend as fallback")
246
+ return aten_mm.bind((mat1, mat2), aten_layout).output_node()
247
+
248
+ try:
249
+ return autotune_select_algorithm(name, choices, [mat1, mat2], layout)
250
+ except NoValidChoicesError:
251
+ if not inductor_config.autotune_fallback_to_aten:
252
+ raise
253
+ log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
254
+ return aten_mm.bind((mat1, mat2), aten_layout).output_node()
255
+
256
+
257
+ def _is_static_problem(inputs_tensors, layout):
258
+ # checks whether all input tensors and the output layout
259
+ # have a static shape by attempting to convert the dimensions
260
+ # to int
261
+ static_shape = True
262
+ static_size = WrapperCodeGen.statically_known_list_of_ints_or_none(layout.size)
263
+ if static_size is None:
264
+ nonzero = True
265
+ for s in layout.size:
266
+ sz = WrapperCodeGen.statically_known_int_or_none(s)
267
+ if sz is not None and sz == 0:
268
+ nonzero = False
269
+ break
270
+ return False, nonzero
271
+ numel = 1
272
+ for dim in static_size:
273
+ numel *= dim
274
+ nonzero = numel > 0
275
+ return static_shape, nonzero
276
+
277
+
278
+ @register_lowering(aten._int_mm, type_promotion_kind=None)
279
+ def tuned_int_mm(mat1, mat2, *, layout=None):
280
+ m, n, k, layout, mat1, mat2 = mm_args(
281
+ mat1, mat2, layout=layout, out_dtype=torch.int32
282
+ )
283
+ static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
284
+ use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k)
285
+
286
+ choices = (
287
+ [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
288
+ )
289
+
290
+ # TODO: Re-enable eager mode implementation once cuBLAS is fixed
291
+ if use_cutlass or use_triton_template(layout, enable_int32=True):
292
+ choices = []
293
+
294
+ if use_cutlass:
295
+ CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
296
+ choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
297
+ )
298
+ if is_nonzero and use_triton_template(layout, enable_int32=True):
299
+ for config in int8_mm_configs(m, n, k):
300
+ mm_template.maybe_append_choice(
301
+ choices,
302
+ input_nodes=(mat1, mat2),
303
+ layout=layout,
304
+ **mm_options(config, m, n, k, layout),
305
+ )
306
+ if len(choices) == 0:
307
+ log.warning(
308
+ "No choices for integer GEMM avaialbe using configured backends, using ATen backend as fallback"
309
+ )
310
+ choices = [aten__int_mm.bind((mat1, mat2), layout)]
311
+
312
+ try:
313
+ return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
314
+ except NoValidChoicesError:
315
+ if not inductor_config.autotune_fallback_to_aten:
316
+ raise
317
+ log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
318
+ choices = [aten__int_mm.bind((mat1, mat2), layout)]
319
+ return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout)
320
+
321
+
322
+ @register_lowering(aten.addmm, type_promotion_kind=None)
323
+ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
324
+ ordered_kwargs_for_cpp_kernel = ("beta", "alpha")
325
+ m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
326
+ static_shape, is_nonzero = _is_static_problem([inp, mat1, mat2], layout)
327
+ if (not is_nonzero) or (not use_max_autotune()):
328
+ # Use a FlexibleLayout if we are not autotuning.
329
+ # This allows padding strides for the output.
330
+ from torch._inductor.ir import FixedLayout, FlexibleLayout
331
+
332
+ if isinstance(layout, FixedLayout):
333
+ layout = FlexibleLayout(
334
+ device=layout.device, dtype=layout.dtype, size=layout.size
335
+ )
336
+ choices = (
337
+ [
338
+ aten_addmm.bind(
339
+ (inp, mat1, mat2),
340
+ layout,
341
+ alpha=alpha,
342
+ beta=beta,
343
+ )
344
+ ]
345
+ if use_aten_gemm_kernels()
346
+ else []
347
+ )
348
+ return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout)
349
+
350
+ choices = (
351
+ [
352
+ aten_addmm.bind(
353
+ (inp_expanded, mat1, mat2),
354
+ layout,
355
+ alpha=alpha,
356
+ beta=beta,
357
+ )
358
+ ]
359
+ if use_aten_gemm_kernels()
360
+ else []
361
+ )
362
+
363
+ if (
364
+ use_aten_gemm_kernels()
365
+ and inp_expanded.get_stride()[0] == 0
366
+ and inp_expanded.get_device().type == "cuda"
367
+ and inductor_config.triton.autotune_cublasLt
368
+ ):
369
+ # unexpand inp to make sure fused addmm from cublasLt is used
370
+ choices.insert(
371
+ 0,
372
+ aten_bias_addmm.bind(
373
+ (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
374
+ ),
375
+ )
376
+
377
+ if is_nonzero and use_triton_template(layout):
378
+ for config in mm_configs(m, n, k):
379
+ mm_template.maybe_append_choice(
380
+ choices,
381
+ input_nodes=(inp_expanded, mat1, mat2),
382
+ layout=layout,
383
+ **mm_options(config, m, n, k, layout),
384
+ prefix_args=1,
385
+ epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
386
+ )
387
+
388
+ if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
389
+ # Filter out a known cause of CUDA illegal memory access errors
390
+ # broadcasting on the last dim of the bias term seems not to be working
391
+ # in the linear GEMM epilogue used by addmm.
392
+ if (
393
+ WrapperCodeGen.statically_known_int_or_none(inp_expanded.layout.stride[-1])
394
+ != 0
395
+ ):
396
+ CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
397
+ choices,
398
+ layout,
399
+ [mat1, mat2, inp_expanded],
400
+ alpha=alpha,
401
+ beta=beta,
402
+ )
403
+
404
+ if is_nonzero and use_ck_template(layout, m, n, k):
405
+ CKGemmTemplate.add_ck_gemm_choices(
406
+ choices,
407
+ layout,
408
+ [mat1, mat2, inp_expanded],
409
+ alpha=alpha,
410
+ beta=beta,
411
+ )
412
+
413
+ if use_cpp_packed_gemm_template(layout, mat1, mat2):
414
+ CppPackedGemmTemplate.add_choices(
415
+ choices,
416
+ layout,
417
+ [inp_expanded, mat1, mat2],
418
+ alpha=alpha,
419
+ beta=beta,
420
+ has_bias=True,
421
+ )
422
+
423
+ add_aten_fallback = False
424
+ if len(choices) == 0:
425
+ log.warning("No choices for GEMM, using ATen backend as fallback")
426
+ add_aten_fallback = True
427
+
428
+ if add_aten_fallback:
429
+ choices.append(
430
+ aten_addmm.bind(
431
+ (inp_expanded, mat1, mat2),
432
+ layout,
433
+ ordered_kwargs_for_cpp_kernel,
434
+ alpha=alpha,
435
+ beta=beta,
436
+ )
437
+ )
438
+
439
+ if (
440
+ inp_expanded.get_stride()[0] == 0
441
+ and inp_expanded.get_device().type == "cuda"
442
+ and inductor_config.triton.autotune_cublasLt
443
+ ):
444
+ # unexpand inp to make sure fused addmm from cublasLt is used
445
+ choices.insert(
446
+ 0,
447
+ aten_bias_addmm.bind(
448
+ (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
449
+ ),
450
+ )
451
+ try:
452
+ return autotune_select_algorithm(
453
+ "addmm", choices, [inp_expanded, mat1, mat2], layout
454
+ )
455
+ except NoValidChoicesError:
456
+ if not inductor_config.autotune_fallback_to_aten:
457
+ raise
458
+ log.warning("All choices for GEMM were invalid, using ATen backend as fallback")
459
+ fallback_choice = aten_addmm.bind(
460
+ (inp, mat1, mat2),
461
+ layout,
462
+ ordered_kwargs_for_cpp_kernel,
463
+ alpha=alpha,
464
+ beta=beta,
465
+ )
466
+ return fallback_choice.output_node()
467
+
468
+
469
+ @register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None)
470
+ def tuned_sparse_semi_structured_mm(
471
+ mat1, mat1_meta, mat2, *, out_dtype=None, layout=None
472
+ ):
473
+ from torch._inductor.select_algorithm import realize_inputs
474
+
475
+ mat1, mat1_meta, mat2 = realize_inputs(mat1, mat1_meta, mat2)
476
+ m1, k1 = mat1.get_size()
477
+ m2, _ = mat1_meta.get_size()
478
+ k2, n = mat2.get_size()
479
+ m = V.graph.sizevars.guard_equals(m1, m2)
480
+ k = V.graph.sizevars.guard_equals(2 * k1, k2)
481
+
482
+ if layout is None:
483
+ from torch._inductor.ir import FixedLayout
484
+
485
+ layout = FixedLayout(
486
+ mat2.get_device(),
487
+ out_dtype if out_dtype else mat2.get_dtype(),
488
+ [m, n],
489
+ [n, 1],
490
+ )
491
+ else:
492
+ assert out_dtype is None, "out_dtype is ignored if layout is specified."
493
+
494
+ choices = (
495
+ [
496
+ aten__sparse_semi_structured_mm.bind(
497
+ (mat1, mat1_meta, mat2), layout, out_dtype=out_dtype
498
+ )
499
+ ]
500
+ if use_aten_gemm_kernels()
501
+ else []
502
+ )
503
+
504
+ if m * n != 0 and use_cutlass_template(layout, m, n, k):
505
+ CUTLASS2xGemmTemplate.add_cutlass_gemm_choices(
506
+ choices, layout, [mat1, mat2, mat1_meta], fuseable=True, non_fuseable=True
507
+ )
508
+
509
+ return autotune_select_algorithm(
510
+ "sparse_semi_structured_mm", choices, [mat1, mat1_meta, mat2], layout
511
+ )
512
+
513
+
514
+ def fallback_mixed_mm(mat1, mat2, *, out):
515
+ return torch.mm(mat1, mat2.to(mat1.dtype), out=out)
516
+
517
+
518
+ aten_fallback_mixed_mm = ExternKernelChoice(fallback_mixed_mm, None)
519
+
520
+
521
+ @functools.lru_cache(None)
522
+ def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool:
523
+ props = torch.cuda.get_device_properties(index or 0)
524
+ return props.major <= 7
525
+
526
+
527
+ def dims_are_int(dims):
528
+ return all(isinstance(dim, int) for dim in dims)
529
+
530
+
531
+ def try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout):
532
+ m, n, k = get_size_hints(mat1, mat2, m, n, k)
533
+ if not dims_are_int([m, n, k]):
534
+ return None
535
+
536
+ if mat1.dtype != torch.float16:
537
+ return None
538
+
539
+ # only use heuristic if we are running on an A100
540
+ # torch.cuda.get_device_capability() >= (8, 0) returns true for A10G
541
+ # which does not have enough shared memory for one of the configs
542
+ if (
543
+ not torch.cuda.get_device_capability() >= (8, 0)
544
+ ) or get_gpu_shared_memory() != 166912:
545
+ return None
546
+
547
+ if m == 1 and (n % 16 != 0 or k % 16 != 0):
548
+ return None
549
+
550
+ if m <= 16 and n >= 4096 and k >= 4096:
551
+ return triton_config(
552
+ BLOCK_M=16,
553
+ BLOCK_N=64,
554
+ BLOCK_K=128,
555
+ num_stages=5,
556
+ num_warps=4,
557
+ )
558
+ elif m > 16 and m <= 32 and n >= 4096 and k >= 4096:
559
+ return triton_config(
560
+ BLOCK_M=32,
561
+ BLOCK_N=32,
562
+ BLOCK_K=128,
563
+ num_stages=5,
564
+ num_warps=4,
565
+ )
566
+ elif m > 32 and m <= 64 and n >= 4096 and k >= 4096:
567
+ return triton_config(
568
+ BLOCK_M=64,
569
+ BLOCK_N=32,
570
+ BLOCK_K=128,
571
+ num_stages=5,
572
+ num_warps=4,
573
+ )
574
+ return None
575
+
576
+
577
+ def mm_autoheuristic(
578
+ mat1,
579
+ mat2,
580
+ m,
581
+ n,
582
+ k,
583
+ choices,
584
+ name,
585
+ input_nodes,
586
+ ops,
587
+ precondition,
588
+ top_k: Optional[int] = None,
589
+ always_included=None,
590
+ ):
591
+ m, n, k = get_size_hints(mat1, mat2, m, n, k)
592
+ if not dims_are_int([m, n, k]):
593
+ return None
594
+ mat1_stride, mat2_stride = get_size_hints_strides(mat1, mat2)
595
+
596
+ def get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride):
597
+ context = AHContext()
598
+ context.add_feature("m", m)
599
+ context.add_feature("k", k)
600
+ context.add_feature("n", n)
601
+ context.add_feature("mat1_dtype", mat1.layout.dtype, is_categorical=True)
602
+ context.add_feature("mat2_dtype", mat2.layout.dtype, is_categorical=True)
603
+ context_add_strides(context, "mat1", mat1_stride)
604
+ context_add_strides(context, "mat2", mat2_stride)
605
+ context.add_feature(
606
+ "mat1_iscontig", mat1.layout.is_contiguous(), is_categorical=True
607
+ )
608
+ context.add_feature(
609
+ "mat2_iscontig", mat2.layout.is_contiguous(), is_categorical=True
610
+ )
611
+ if name == "mm":
612
+ # for mixed_mm, we only consider fp16
613
+ context_add_using_tf32(context, mat1.layout.dtype)
614
+ return context
615
+
616
+ def fallback():
617
+ return None
618
+
619
+ context = get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride)
620
+ autoheuristic = AutoHeuristicSelectAlgorithm(
621
+ fallback=fallback,
622
+ choices=choices,
623
+ input_nodes=input_nodes,
624
+ context=context,
625
+ name=name,
626
+ augment_context=ops,
627
+ precondition=precondition,
628
+ )
629
+
630
+ if top_k is not None:
631
+ # TODO: is there a cleaner way to ensure aten.mm is always included?
632
+ return autoheuristic.get_top_k_choices_caller(
633
+ top_k, always_included=always_included
634
+ )
635
+
636
+ return autoheuristic.get_choice_caller()
637
+
638
+
639
+ def get_size_hints(mat1, mat2, m, n, k):
640
+ if not isinstance(m, int) or not isinstance(k, int):
641
+ (m, k) = V.graph.sizevars.size_hints(
642
+ mat1.get_size(),
643
+ fallback=torch._inductor.config.unbacked_symint_fallback,
644
+ )
645
+
646
+ if not isinstance(n, int) or not isinstance(k, int):
647
+ (k, n) = V.graph.sizevars.size_hints(
648
+ mat2.get_size(),
649
+ fallback=torch._inductor.config.unbacked_symint_fallback,
650
+ )
651
+ return m, n, k
652
+
653
+
654
+ def get_size_hints_strides(mat1, mat2):
655
+ mat1_stride = mat1.layout.stride
656
+ mat2_stride = mat2.layout.stride
657
+ strides = [mat1_stride, mat2_stride]
658
+ strides_hints = []
659
+ for stride in strides:
660
+ if not isinstance(stride, int):
661
+ stride = V.graph.sizevars.size_hints(
662
+ stride,
663
+ fallback=torch._inductor.config.unbacked_symint_fallback,
664
+ )
665
+ strides_hints.append(stride)
666
+ return strides_hints[0], strides_hints[1]
667
+
668
+
669
+ def tuned_mixed_mm(mat1, mat2, mat2_dtype):
670
+ m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None)
671
+ static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
672
+
673
+ fallback = aten_fallback_mixed_mm.bind((mat1, mat2), layout)
674
+
675
+ choices = [fallback]
676
+
677
+ # skip the Triton kernel if any of the conditions below holds, or when running on a V100 or older GPU (numerical issues)
678
+ skip_triton = (
679
+ (
680
+ mat1.layout.dtype != torch.float32
681
+ and not (mat2.layout.is_contiguous() or mat2.layout.is_transposed())
682
+ )
683
+ or _is_sm7x_or_older_gpu(layout.device.index)
684
+ or inductor_config.mixed_mm_choice == "aten"
685
+ or not V.graph.has_feature(layout.device, BackendFeature.TRITON_TEMPLATES)
686
+ or (
687
+ mat1.layout.dtype == torch.float32 and torch.backends.cuda.matmul.allow_tf32
688
+ )
689
+ or (mat1.layout.dtype == torch.bfloat16 and mat2.layout.dtype == torch.uint8)
690
+ )
691
+
692
+ if inductor_config.mixed_mm_choice == "triton":
693
+ choices = []
694
+
695
+ if not skip_triton:
696
+ b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
697
+ if static_shape and inductor_config.mixed_mm_choice == "heuristic":
698
+ choices = []
699
+ config = try_heuristic(m, n, k, choices, mat1, mat2, mat2_dtype, layout)
700
+ if config is not None:
701
+ mm_template.maybe_append_choice(
702
+ choices,
703
+ input_nodes=(mat1, mat2),
704
+ layout=layout,
705
+ **mm_options(config, m, n, k, layout, b_prologue_cast_type),
706
+ )
707
+ choices.append(fallback)
708
+
709
+ has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2)
710
+ for config in mixed_mm_configs(m, n, k, has_int8_tensor=has_int8_tensor):
711
+ mm_template.maybe_append_choice(
712
+ choices,
713
+ input_nodes=(mat1, mat2),
714
+ layout=layout,
715
+ **mm_options(config, m, n, k, layout, b_prologue_cast_type),
716
+ )
717
+
718
+ if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
719
+ CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
720
+ choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
721
+ )
722
+ CUTLASS2xGemmTemplate.add_cutlass_gemm_choices(
723
+ choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
724
+ )
725
+
726
+ if skip_triton and not choices:
727
+ choices = [fallback]
728
+
729
+ name = "mixed_mm"
730
+ input_nodes = [mat1, mat2]
731
+ if torch._inductor.config.run_autoheuristic(name):
732
+ choice = mm_autoheuristic(
733
+ mat1,
734
+ mat2,
735
+ m,
736
+ n,
737
+ k,
738
+ choices,
739
+ name,
740
+ input_nodes,
741
+ mixed_mm_operations(),
742
+ get_mixedmm_precondition,
743
+ )
744
+ if (
745
+ not skip_triton
746
+ and inductor_config.mixed_mm_choice == "heuristic"
747
+ and choice is not None
748
+ ):
749
+ choices.insert(0, choice)
750
+ return autotune_select_algorithm(name, choices, input_nodes, layout)
751
+
752
+
753
+ # This op is a special case of the int_mm op which we use based on the pattern
754
+ # _int_mm -> mul (defined in ../fx_passes/post_grad.py) in order to prevent
755
+ # realization of the int32 _int_mm output by forcing fusion with the mul op.
756
+ # This is only used when config.force_fuse_int_mm_with_mul = True
757
+ def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
758
+ out_dtype = (
759
+ torch.promote_types(mat3.get_dtype(), torch.int32)
760
+ if out_dtype is None
761
+ else out_dtype
762
+ )
763
+ m, n, k, layout, mat1, mat2, mat3 = mm_args(
764
+ mat1, mat2, mat3, layout=layout, out_dtype=out_dtype
765
+ )
766
+ choices: List[Dict[Any, Any]] = []
767
+ for config in int8_mm_configs(m, n, k):
768
+ mm_template.maybe_append_choice(
769
+ choices,
770
+ input_nodes=(mat1, mat2, mat3),
771
+ layout=layout,
772
+ **dict(mm_options(config, m, n, k, layout), ACC_TYPE="tl.int32"),
773
+ suffix_args=1,
774
+ epilogue_fn=V.ops.mul,
775
+ )
776
+ return autotune_select_algorithm("int_mm", choices, [mat1, mat2, mat3], layout)
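In eager terms, the pattern this lowering fuses is an int8 matmul whose int32 result is immediately multiplied. A rough reference of the intended semantics (a sketch, assuming torch._int_mm is available for the given device and shapes; the helper name is illustrative):

    import torch

    def fused_int_mm_mul_reference(mat1, mat2, mat3, out_dtype=None):
        # int8 @ int8 accumulates into an int32 tensor; the multiply with mat3
        # is what the template above fuses into the matmul epilogue.
        acc = torch._int_mm(mat1, mat2)
        out = acc * mat3
        return out if out_dtype is None else out.to(out_dtype)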
.venv/Lib/site-packages/torch/_inductor/kernel/mm_common.py ADDED
@@ -0,0 +1,466 @@
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+ import itertools
4
+ import logging
5
+ from typing import cast, List, Tuple
6
+
7
+ import sympy
8
+
9
+ import torch
10
+ from torch._inductor.select_algorithm import realize_inputs
11
+ from torch._inductor.virtualized import V
12
+
13
+ from .. import config as inductor_config
14
+ from ..runtime.runtime_utils import next_power_of_2
15
+ from ..utils import ceildiv as cdiv
16
+
17
+
18
+ log = logging.getLogger(__name__)
19
+
20
+
21
+ def triton_config(num_stages, num_warps, **kwargs):
22
+ from triton import Config
23
+
24
+ return Config(kwargs, num_stages=num_stages, num_warps=num_warps)
25
+
26
+
27
+ def filtered_configs(
28
+ m: int,
29
+ n: int,
30
+ k: int,
31
+ configs: List[Tuple[int, int, int, int, int]],
32
+ has_int8_tensor=False,
33
+ ):
34
+ """Heuristic to shrink configs when they are bigger than the input size"""
35
+
36
+ min_block_size = 16
37
+ # block_k=16 seems to be causing issues
38
+ # see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424
39
+ min_block_size_k = 32 if has_int8_tensor else 16
40
+ m = max(
41
+ next_power_of_2(
42
+ V.graph.sizevars.size_hint(
43
+ m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
44
+ )
45
+ ),
46
+ min_block_size,
47
+ )
48
+ n = max(
49
+ next_power_of_2(
50
+ V.graph.sizevars.size_hint(
51
+ n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
52
+ )
53
+ ),
54
+ min_block_size,
55
+ )
56
+ k = max(
57
+ next_power_of_2(
58
+ V.graph.sizevars.size_hint(
59
+ k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
60
+ )
61
+ ),
62
+ min_block_size_k,
63
+ )
64
+ used = set()
65
+ for block_m, block_n, block_k, num_stages, num_warps in configs:
66
+ # shrink configs for small sizes
67
+ block_m = max(min(block_m, m), min_block_size)
68
+ block_n = max(min(block_n, n), min_block_size)
69
+ block_k = max(min(block_k, k), min_block_size_k)
70
+ # each warp computes a 16x16 tile, i.e. 256 elements
71
+ num_warps = min(num_warps, block_m * block_n // 256)
72
+ if torch.version.hip:
73
+ for matrix_instr_nonkdim in [0, 16]:
74
+ if matrix_instr_nonkdim != 0 and (
75
+ block_m % matrix_instr_nonkdim != 0
76
+ or block_n % matrix_instr_nonkdim != 0
77
+ ):
78
+ # block_m and block_n must be a multiple of matrix_instr_nonkdim
79
+ continue
80
+ if (
81
+ block_m,
82
+ block_n,
83
+ block_k,
84
+ num_stages,
85
+ num_warps,
86
+ matrix_instr_nonkdim,
87
+ ) not in used:
88
+ used.add(
89
+ (
90
+ block_m,
91
+ block_n,
92
+ block_k,
93
+ num_stages,
94
+ num_warps,
95
+ matrix_instr_nonkdim,
96
+ )
97
+ )
98
+ yield triton_config(
99
+ BLOCK_M=block_m,
100
+ BLOCK_N=block_n,
101
+ BLOCK_K=block_k,
102
+ num_stages=num_stages,
103
+ num_warps=num_warps,
104
+ matrix_instr_nonkdim=matrix_instr_nonkdim,
105
+ )
106
+ else:
107
+ if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used:
108
+ used.add((block_m, block_n, block_k, num_stages, num_warps, 0))
109
+ yield triton_config(
110
+ BLOCK_M=block_m,
111
+ BLOCK_N=block_n,
112
+ BLOCK_K=block_k,
113
+ num_stages=num_stages,
114
+ num_warps=num_warps,
115
+ )
116
+
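In isolation, the shrinking logic above rounds each problem dimension up to a power of two (but never below the minimum block size), clamps each block dimension to that rounded size, and caps num_warps so each warp still owns a full 16x16 tile. A standalone sketch of that arithmetic (the local next-power-of-two helper and function name are illustrative):

    def shrink_config(block_m, block_n, block_k, num_warps, m, n, k,
                      min_block=16, min_block_k=16):
        # round problem sizes up to powers of two, but not below the minimum block
        npow2 = lambda x: 1 << (x - 1).bit_length()
        m, n, k = (max(npow2(d), lo) for d, lo in ((m, min_block), (n, min_block), (k, min_block_k)))
        # clamp each block dimension to the (rounded) problem size
        block_m = max(min(block_m, m), min_block)
        block_n = max(min(block_n, n), min_block)
        block_k = max(min(block_k, k), min_block_k)
        # each warp covers a 16x16 tile, i.e. 256 elements
        num_warps = min(num_warps, block_m * block_n // 256)
        return block_m, block_n, block_k, num_warps

    # e.g. the (128, 128, 32, ..., 8) config applied to a tiny 20x20x20 problem
    # collapses to blocks of (32, 32, 32) with num_warps capped at 4:
    assert shrink_config(128, 128, 32, 8, 20, 20, 20) == (32, 32, 32, 4)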
117
+
118
+ # List of dictionaries to store the kernel configs. Configs whose cond evaluates to true
119
+ # will be utilised on the target platform. The configs are as follows:
120
+ # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
121
+ mm_kernel_configs = (
122
+ [
123
+ {"config": (32, 32, 16, 1, 2), "cond": True},
124
+ {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},
125
+ {"config": (32, 64, 32, 5, 8), "cond": True},
126
+ {"config": (64, 32, 32, 5, 8), "cond": True},
127
+ {"config": (64, 32, 128, 5, 4), "cond": True},
128
+ {"config": (64, 64, 16, 2, 4), "cond": True},
129
+ {"config": (64, 64, 32, 2, 4), "cond": True},
130
+ {"config": (64, 64, 64, 3, 8), "cond": True},
131
+ {"config": (64, 64, 128, 5, 4), "cond": True},
132
+ {"config": (64, 128, 32, 3, 4), "cond": True},
133
+ {"config": (64, 128, 32, 4, 8), "cond": True},
134
+ {"config": (64, 128, 64, 3, 4), "cond": True},
135
+ {"config": (64, 128, 128, 4, 4), "cond": True},
136
+ {"config": (128, 64, 32, 3, 4), "cond": True},
137
+ {"config": (128, 64, 32, 4, 8), "cond": True},
138
+ {"config": (128, 128, 32, 2, 8), "cond": True},
139
+ {"config": (128, 128, 32, 3, 4), "cond": True},
140
+ {"config": (128, 128, 64, 3, 4), "cond": True},
141
+ {"config": (128, 128, 64, 5, 8), "cond": True},
142
+ ]
143
+ if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
144
+ else [
145
+ {"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True}
146
+ for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
147
+ [16, 32, 64, 128, 256], repeat=3
148
+ )
149
+ for num_stages in [1, 2, 3, 4, 5]
150
+ for num_warps in [2, 4, 8]
151
+ ]
152
+ )
153
+
154
+ # these are only used in tuned_mm when AutoHeuristic is enabled
155
+ # the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned
156
+ # when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10
157
+ # which saves compilation time (since fewer configs are autotuned) and potentially increases performance
158
+ # because the learned heuristic might predict a config that is not part of mm_configs
159
+ extra_mm_kernel_configs = [
160
+ {"config": (16, 32, 16, 3, 2), "cond": True},
161
+ {"config": (16, 32, 32, 4, 2), "cond": True},
162
+ {"config": (16, 32, 32, 5, 2), "cond": True},
163
+ {"config": (64, 64, 128, 3, 4), "cond": True},
164
+ {"config": (128, 64, 32, 2, 2), "cond": True},
165
+ {"config": (128, 64, 64, 3, 8), "cond": True},
166
+ {"config": (128, 64, 128, 4, 8), "cond": True},
167
+ {"config": (128, 128, 32, 4, 4), "cond": True},
168
+ {"config": (128, 128, 64, 3, 8), "cond": True},
169
+ {"config": (128, 128, 64, 5, 4), "cond": True},
170
+ ]
171
+
172
+ int8_mm_kernel_configs = [
173
+ {"config": (64, 64, 32, 2, 4), "cond": True},
174
+ {"config": (64, 128, 32, 3, 4), "cond": True},
175
+ {"config": (128, 64, 32, 3, 4), "cond": True},
176
+ {"config": (64, 128, 32, 4, 8), "cond": True},
177
+ {"config": (128, 64, 32, 4, 8), "cond": True},
178
+ {"config": (64, 32, 32, 5, 8), "cond": True},
179
+ {"config": (32, 64, 32, 5, 8), "cond": True},
180
+ {"config": (128, 128, 32, 2, 8), "cond": True},
181
+ {"config": (64, 64, 64, 3, 8), "cond": True},
182
+ # {"config": (32, 32, 128, 2, 4), "cond": True},
183
+ # {"config": (64, 64, 16, 2, 4), "cond": True},
184
+ # {"config": (32, 32, 16, 1, 2), "cond": True},
185
+ {"config": (128, 256, 128, 3, 8), "cond": torch.version.hip is None},
186
+ {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None},
187
+ ]
188
+
189
+ # Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192).
190
+ mixed_mm_kernel_configs_small_m = [
191
+ {"config": (16, 128, 256, 3, 4), "cond": True},
192
+ {"config": (16, 128, 256, 5, 8), "cond": True},
193
+ ]
194
+
195
+ mixed_mm_kernel_configs = (
196
+ mm_kernel_configs + mixed_mm_kernel_configs_small_m
197
+ if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
198
+ else mm_kernel_configs
199
+ )
200
+
201
+ scaled_mm_kernel_configs = [
202
+ {"config": (128, 256, 32, 3, 8), "cond": True},
203
+ {"config": (256, 128, 32, 3, 8), "cond": True},
204
+ {"config": (256, 64, 32, 4, 4), "cond": True},
205
+ {"config": (64, 256, 32, 4, 4), "cond": True},
206
+ {"config": (128, 128, 32, 4, 4), "cond": True},
207
+ {"config": (128, 64, 32, 4, 4), "cond": True},
208
+ {"config": (64, 128, 32, 4, 4), "cond": True},
209
+ {"config": (128, 32, 32, 4, 4), "cond": True},
210
+ {"config": (64, 32, 32, 5, 2), "cond": True},
211
+ {"config": (256, 128, 128, 3, 8), "cond": True},
212
+ {"config": (256, 64, 128, 4, 4), "cond": True},
213
+ {"config": (64, 256, 128, 4, 4), "cond": True},
214
+ {"config": (128, 128, 128, 4, 4), "cond": True},
215
+ {"config": (128, 64, 64, 4, 4), "cond": True},
216
+ {"config": (64, 128, 64, 4, 4), "cond": True},
217
+ {"config": (128, 32, 64, 4, 4), "cond": True},
218
+ {"config": (64, 32, 64, 5, 2), "cond": True},
219
+ {"config": (16, 32, 32, 2, 2), "cond": True},
220
+ {"config": (16, 64, 32, 2, 2), "cond": True},
221
+ {"config": (16, 128, 32, 2, 4), "cond": True},
222
+ {"config": (16, 256, 32, 2, 4), "cond": True},
223
+ {"config": (16, 32, 64, 2, 2), "cond": True},
224
+ {"config": (16, 64, 64, 2, 2), "cond": True},
225
+ {"config": (16, 128, 64, 2, 4), "cond": True},
226
+ {"config": (16, 256, 64, 2, 4), "cond": True},
227
+ {"config": (32, 32, 32, 2, 2), "cond": True},
228
+ {"config": (32, 64, 32, 2, 2), "cond": True},
229
+ {"config": (32, 128, 32, 2, 4), "cond": True},
230
+ {"config": (32, 256, 32, 2, 4), "cond": True},
231
+ {"config": (32, 32, 64, 2, 2), "cond": True},
232
+ {"config": (32, 64, 64, 2, 2), "cond": True},
233
+ {"config": (32, 128, 64, 2, 4), "cond": True},
234
+ {"config": (32, 256, 64, 2, 4), "cond": True},
235
+ {"config": (16, 32, 32, 3, 2), "cond": True},
236
+ {"config": (16, 64, 32, 3, 2), "cond": True},
237
+ {"config": (16, 128, 32, 3, 4), "cond": True},
238
+ {"config": (16, 256, 32, 3, 4), "cond": True},
239
+ {"config": (16, 32, 64, 3, 2), "cond": True},
240
+ {"config": (16, 64, 64, 3, 2), "cond": True},
241
+ {"config": (16, 128, 64, 3, 4), "cond": True},
242
+ {"config": (16, 256, 64, 3, 4), "cond": True},
243
+ {"config": (32, 32, 32, 3, 2), "cond": True},
244
+ {"config": (32, 64, 32, 3, 2), "cond": True},
245
+ {"config": (32, 128, 32, 3, 4), "cond": True},
246
+ {"config": (32, 256, 32, 3, 4), "cond": True},
247
+ {"config": (32, 32, 64, 3, 2), "cond": True},
248
+ {"config": (32, 64, 64, 3, 2), "cond": True},
249
+ {"config": (32, 128, 64, 3, 4), "cond": True},
250
+ {"config": (32, 256, 64, 3, 4), "cond": True},
251
+ {"config": (16, 32, 32, 4, 2), "cond": True},
252
+ {"config": (16, 64, 32, 4, 2), "cond": True},
253
+ {"config": (16, 128, 32, 4, 4), "cond": True},
254
+ {"config": (16, 256, 32, 4, 4), "cond": True},
255
+ {"config": (16, 32, 64, 4, 2), "cond": True},
256
+ {"config": (16, 64, 64, 4, 2), "cond": True},
257
+ {"config": (16, 128, 64, 4, 4), "cond": True},
258
+ {"config": (16, 256, 64, 4, 4), "cond": True},
259
+ {"config": (32, 32, 32, 4, 2), "cond": True},
260
+ {"config": (32, 64, 32, 4, 2), "cond": True},
261
+ {"config": (32, 128, 32, 4, 4), "cond": True},
262
+ {"config": (32, 256, 32, 4, 4), "cond": True},
263
+ {"config": (32, 32, 64, 4, 2), "cond": True},
264
+ {"config": (32, 64, 64, 4, 2), "cond": True},
265
+ {"config": (32, 128, 64, 4, 4), "cond": True},
266
+ {"config": (32, 256, 64, 4, 4), "cond": True},
267
+ {"config": (16, 32, 32, 5, 2), "cond": True},
268
+ {"config": (16, 64, 32, 5, 2), "cond": True},
269
+ {"config": (16, 128, 32, 5, 4), "cond": True},
270
+ {"config": (16, 256, 32, 5, 4), "cond": True},
271
+ {"config": (16, 32, 64, 5, 2), "cond": True},
272
+ {"config": (16, 64, 64, 5, 2), "cond": True},
273
+ {"config": (16, 128, 64, 5, 4), "cond": True},
274
+ {"config": (16, 256, 64, 5, 4), "cond": True},
275
+ {"config": (32, 32, 32, 5, 2), "cond": True},
276
+ {"config": (32, 64, 32, 5, 2), "cond": True},
277
+ {"config": (32, 128, 32, 5, 4), "cond": True},
278
+ {"config": (32, 256, 32, 5, 4), "cond": True},
279
+ {"config": (32, 32, 64, 5, 2), "cond": True},
280
+ {"config": (32, 64, 64, 5, 2), "cond": True},
281
+ {"config": (32, 128, 64, 5, 4), "cond": True},
282
+ {"config": (32, 256, 64, 5, 4), "cond": True},
283
+ {"config": (16, 32, 32, 6, 2), "cond": True},
284
+ {"config": (16, 64, 32, 6, 2), "cond": True},
285
+ {"config": (16, 128, 32, 6, 4), "cond": True},
286
+ {"config": (16, 256, 32, 6, 4), "cond": True},
287
+ {"config": (16, 32, 64, 6, 2), "cond": True},
288
+ {"config": (16, 64, 64, 6, 2), "cond": True},
289
+ {"config": (16, 128, 64, 6, 4), "cond": True},
290
+ {"config": (16, 256, 64, 6, 4), "cond": True},
291
+ {"config": (32, 32, 32, 6, 2), "cond": True},
292
+ {"config": (32, 64, 32, 6, 2), "cond": True},
293
+ {"config": (32, 128, 32, 6, 4), "cond": True},
294
+ {"config": (32, 256, 32, 6, 4), "cond": True},
295
+ {"config": (32, 32, 64, 6, 2), "cond": True},
296
+ {"config": (32, 64, 64, 6, 2), "cond": True},
297
+ {"config": (32, 128, 64, 6, 4), "cond": True},
298
+ {"config": (32, 256, 64, 6, 4), "cond": True},
299
+ ]
300
+
301
+
302
+ # Create filtered list of configs based on cond evaluation
303
+ mm_platform_configs = tuple(
304
+ cast(Tuple[int, int, int, int, int], config["config"])
305
+ for config in mm_kernel_configs
306
+ if config["cond"]
307
+ )
308
+ extra_mm_platform_configs = tuple(
309
+ cast(Tuple[int, int, int, int, int], config["config"])
310
+ for config in extra_mm_kernel_configs
311
+ if config["cond"]
312
+ )
313
+ int8_platform_configs = tuple(
314
+ cast(Tuple[int, int, int, int, int], config["config"])
315
+ for config in int8_mm_kernel_configs
316
+ if config["cond"]
317
+ )
318
+ mixed_mm_platform_configs = tuple(
319
+ cast(Tuple[int, int, int, int, int], config["config"])
320
+ for config in mixed_mm_kernel_configs
321
+ if config["cond"]
322
+ )
323
+ scaled_mm_platform_configs = tuple(
324
+ cast(Tuple[int, int, int, int, int], config["config"])
325
+ for config in scaled_mm_kernel_configs
326
+ if config["cond"]
327
+ )
328
+
329
+ # On ROCm convert num_stages to 0 to enable software pipelining
330
+ if torch.version.hip:
331
+ mm_platform_configs = tuple(
332
+ (config[0], config[1], config[2], 0, config[4])
333
+ for config in mm_platform_configs
334
+ )
335
+ extra_mm_platform_configs = tuple(
336
+ (config[0], config[1], config[2], 0, config[4])
337
+ for config in extra_mm_platform_configs
338
+ )
339
+ int8_platform_configs = tuple(
340
+ (config[0], config[1], config[2], 0, config[4])
341
+ for config in int8_platform_configs
342
+ )
343
+ mixed_mm_platform_configs = tuple(
344
+ (config[0], config[1], config[2], 0, config[4])
345
+ for config in mixed_mm_platform_configs
346
+ )
347
+ scaled_mm_platform_configs = tuple(
348
+ (config[0], config[1], config[2], 0, config[4])
349
+ for config in scaled_mm_platform_configs
350
+ )
351
+
352
+ mm_configs = functools.partial(
353
+ filtered_configs,
354
+ configs=mm_platform_configs,
355
+ )
356
+
357
+ extra_mm_configs = functools.partial(
358
+ filtered_configs,
359
+ configs=extra_mm_platform_configs,
360
+ )
361
+
362
+ int8_mm_configs = functools.partial(
363
+ filtered_configs,
364
+ configs=int8_platform_configs,
365
+ )
366
+
367
+ mixed_mm_configs = functools.partial(
368
+ filtered_configs,
369
+ configs=mixed_mm_platform_configs,
370
+ )
371
+
372
+ scaled_mm_configs = functools.partial(
373
+ filtered_configs,
374
+ configs=scaled_mm_platform_configs,
375
+ )
376
+
377
+
378
+ def mm_grid(m, n, meta):
379
+ """
380
+ The CUDA grid size for matmul triton templates.
381
+ """
382
+ return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)
383
+
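As a quick worked example (integer sizes assumed): a 300x200 output tiled with BLOCK_M=64 and BLOCK_N=64 launches ceil(300/64) * ceil(200/64) = 5 * 4 = 20 programs along the first grid dimension:

    meta = {"BLOCK_M": 64, "BLOCK_N": 64}
    assert mm_grid(300, 200, meta) == (20, 1, 1)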
384
+
385
+ def acc_type(dtype):
386
+ if dtype in (torch.float16, torch.bfloat16):
387
+ return "tl.float32"
388
+ return f"tl.{dtype}".replace("torch.", "")
389
+
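For instance, half-precision inputs accumulate in fp32, while other dtypes keep an accumulator of their own type:

    assert acc_type(torch.float16) == "tl.float32"
    assert acc_type(torch.float32) == "tl.float32"
    assert acc_type(torch.int8) == "tl.int8"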
390
+
391
+ def mm_options(config, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None):
392
+ """
393
+ Common options to matmul triton templates.
394
+ """
395
+ even_k_symbolic = (
396
+ # it isn't worth guarding on this
397
+ sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
398
+ == config.kwargs["BLOCK_K"]
399
+ )
400
+ allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
401
+ not inductor_config.force_same_precision
402
+ or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0)
403
+ )
404
+ return dict(
405
+ GROUP_M=8,
406
+ EVEN_K=even_k_symbolic,
407
+ ALLOW_TF32=allow_tf32,
408
+ ACC_TYPE=acc_type(layout.dtype),
409
+ B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
410
+ num_stages=config.num_stages,
411
+ num_warps=config.num_warps,
412
+ **config.kwargs,
413
+ )
414
+
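The EVEN_K flag simply asks whether BLOCK_K divides K exactly, which lets the template skip the masked tail loads; for example:

    import sympy

    assert sympy.gcd(4096, 128) == 128  # K=4096, BLOCK_K=128 -> EVEN_K is True
    assert sympy.gcd(4100, 128) != 128  # K=4100 leaves a ragged tail, EVEN_K is False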
415
+
416
+ def mm_args(
417
+ mat1,
418
+ mat2,
419
+ *others,
420
+ layout=None,
421
+ out_dtype=None,
422
+ use_4x2_dim=False,
423
+ mat2_transposed=False,
424
+ ):
425
+ """
426
+ Common arg processing for mm,bmm,addmm,etc
427
+ """
428
+ mat1, mat2 = realize_inputs(mat1, mat2)
429
+ *b1, m, k1 = mat1.get_size()
430
+ if mat2_transposed:
431
+ *b2, n, k2 = mat2.get_size()
432
+ else:
433
+ *b2, k2, n = mat2.get_size()
434
+ b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
435
+ if use_4x2_dim:
436
+ k2 = k2 * 2
437
+ k = V.graph.sizevars.guard_equals(k1, k2)
438
+ if layout is None:
439
+ from torch._inductor.ir import FixedLayout
440
+
441
+ if out_dtype is None:
442
+ out_dtype = mat1.get_dtype()
443
+
444
+ layout = FixedLayout(
445
+ mat1.get_device(),
446
+ out_dtype,
447
+ [*b, m, n],
448
+ )
449
+ else:
450
+ assert out_dtype is None, "out_dtype is ignored if layout is specified."
451
+ from ..lowering import expand
452
+
453
+ others = [realize_inputs(expand(x, layout.size)) for x in others]
454
+
455
+ return [m, n, k, layout, mat1, mat2, *others]
456
+
457
+
458
+ def addmm_epilogue(dtype, alpha, beta):
459
+ def epilogue(acc, bias):
460
+ if alpha != 1:
461
+ acc = V.ops.mul(acc, V.ops.constant(alpha, dtype))
462
+ if beta != 1:
463
+ bias = V.ops.mul(bias, V.ops.constant(beta, dtype))
464
+ return V.ops.add(acc, bias)
465
+
466
+ return epilogue
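The epilogue above mirrors the usual addmm arithmetic: the accumulator is scaled by alpha, the bias by beta, and the two are summed. An eager-mode equivalence check of that math (a sketch of the semantics, not of the V.ops IR plumbing; the helper name is illustrative):

    import torch

    def addmm_reference(inp, mat1, mat2, alpha=1.0, beta=1.0):
        # what the generated epilogue computes per accumulator tile
        return alpha * (mat1 @ mat2) + beta * inp

    inp, mat1, mat2 = torch.randn(8), torch.randn(4, 16), torch.randn(16, 8)
    expected = torch.addmm(inp, mat1, mat2, alpha=2.0, beta=0.5)
    assert torch.allclose(addmm_reference(inp, mat1, mat2, 2.0, 0.5), expected, atol=1e-5)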
.venv/Lib/site-packages/torch/_inductor/kernel/mm_plus_mm.py ADDED
@@ -0,0 +1,248 @@
1
+ # mypy: allow-untyped-defs
2
+ import functools
3
+
4
+ import torch
5
+
6
+ from ..lowering import lowerings
7
+ from ..select_algorithm import (
8
+ autotune_select_algorithm,
9
+ ExternKernelChoice,
10
+ TritonTemplate,
11
+ )
12
+ from ..utils import use_aten_gemm_kernels, use_triton_template
13
+ from ..virtualized import V
14
+ from .mm_common import mm_args, mm_grid, mm_options
15
+
16
+
17
+ aten = torch.ops.aten
18
+
19
+ aten_mm_plus_mm = ExternKernelChoice(
20
+ torch.ops.inductor._mm_plus_mm, "torch::inductor::_mm_plus_mm"
21
+ )
22
+
23
+ mm_plus_mm_template = TritonTemplate(
24
+ name="mm_plus_mm",
25
+ grid=mm_grid,
26
+ debug=False,
27
+ source=r"""
28
+ {{def_kernel("A", "B", "C", "D")}}
29
+ M = {{size("A", 0)}}
30
+ N = {{size("B", 1)}}
31
+ K1 = {{size("A", 1)}}
32
+ if M * N == 0:
33
+ # early exit due to zero-size input(s)
34
+ return
35
+ # K2 = {{size("C", 1)}}
36
+ stride_am = {{stride("A", 0)}}
37
+ stride_ak = {{stride("A", 1)}}
38
+ stride_bk = {{stride("B", 0)}}
39
+ stride_bn = {{stride("B", 1)}}
40
+ stride_cm = {{stride("C", 0)}}
41
+ stride_ck = {{stride("C", 1)}}
42
+ stride_dk = {{stride("D", 0)}}
43
+ stride_dn = {{stride("D", 1)}}
44
+
45
+ # based on triton.ops.matmul
46
+ pid = tl.program_id(0)
47
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
48
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
49
+
50
+ # re-order program ID for better L2 performance
51
+ width = GROUP_M * grid_n
52
+ group_id = pid // width
53
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
54
+ pid_m = group_id * GROUP_M + (pid % group_size)
55
+ pid_n = (pid % width) // (group_size)
56
+
57
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
58
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
59
+
60
+ if (((stride_am == 1 and stride_ak == M) or (stride_am == K1 and stride_ak == 1))
61
+ and ((stride_cm == 1 and stride_ck == M) or (stride_cm == K1 and stride_ck == 1))):
62
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
63
+ else:
64
+ ram = rm % M
65
+
66
+ if (((stride_bk == 1 and stride_bn == K1) or (stride_bk == N and stride_bn == 1))
67
+ and ((stride_dk == 1 and stride_dn == K1) or (stride_dk == N and stride_dn == 1))):
68
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
69
+ else:
70
+ rbn = rn % N
71
+
72
+ rk = tl.arange(0, BLOCK_K)
73
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
74
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
75
+ C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck)
76
+ D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn)
77
+
78
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
79
+ for k1 in range(K1, 0, -BLOCK_K):
80
+ # First matmul with A @ B
81
+ if EVEN_K:
82
+ a = tl.load(A)
83
+ b = tl.load(B)
84
+ else:
85
+ a = tl.load(A, mask=rk[None, :] < k1, other=0.)
86
+ b = tl.load(B, mask=rk[:, None] < k1, other=0.)
87
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
88
+ A += BLOCK_K * stride_ak
89
+ B += BLOCK_K * stride_bk
90
+
91
+ for k2 in range(K1, 0, -BLOCK_K):
92
+
93
+ # Second matmul with C @ D
94
+ if EVEN_K:
95
+ c = tl.load(C)
96
+ d = tl.load(D)
97
+ else:
98
+ c = tl.load(C, mask=rk[None, :] < k2, other=0.)
99
+ d = tl.load(D, mask=rk[:, None] < k2, other=0.)
100
+ acc += tl.dot(c, d, allow_tf32=ALLOW_TF32)
101
+ C += BLOCK_K * stride_ck
102
+ D += BLOCK_K * stride_dk
103
+
104
+
105
+ idx_m = rm[:, None]
106
+ idx_n = rn[None, :]
107
+ mask = (idx_m < M) & (idx_n < N)
108
+
109
+ # inductor generates a suffix
110
+ {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
111
+ """,
112
+ )
113
+
114
+
115
+ @functools.lru_cache(None)
116
+ def mm_configs():
117
+ import triton
118
+
119
+ # List of dictionaries to store the kernel configs. Configs whose cond evaluates to true
120
+ # will be utilised on the target platform
121
+ mm_triton_configs = [
122
+ {
123
+ "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
124
+ "num_stages": 2,
125
+ "num_warps": 4,
126
+ "cond": True,
127
+ },
128
+ {
129
+ "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
130
+ "num_stages": 3,
131
+ "num_warps": 8,
132
+ "cond": True,
133
+ },
134
+ {
135
+ "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
136
+ "num_stages": 4,
137
+ "num_warps": 16,
138
+ "cond": True,
139
+ },
140
+ {
141
+ "config": {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32},
142
+ "num_stages": 4,
143
+ "num_warps": 8,
144
+ "cond": True,
145
+ },
146
+ {
147
+ "config": {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32},
148
+ "num_stages": 4,
149
+ "num_warps": 8,
150
+ "cond": True,
151
+ },
152
+ {
153
+ "config": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32},
154
+ "num_stages": 1,
155
+ "num_warps": 8,
156
+ "cond": True,
157
+ },
158
+ {
159
+ "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64},
160
+ "num_stages": 1,
161
+ "num_warps": 8,
162
+ "cond": True,
163
+ },
164
+ {
165
+ "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128},
166
+ "num_stages": 1,
167
+ "num_warps": 8,
168
+ "cond": torch.version.hip is None,
169
+ },
170
+ {
171
+ "config": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16},
172
+ "num_stages": 2,
173
+ "num_warps": 4,
174
+ "cond": True,
175
+ },
176
+ {
177
+ "config": {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16},
178
+ "num_stages": 1,
179
+ "num_warps": 2,
180
+ "cond": True,
181
+ },
182
+ ]
183
+
184
+ # Keep only the configs whose cond evaluates to true
185
+ # On ROCm convert num_stages to 1 as pipelining provides no benefit
186
+ if torch.version.hip:
187
+ filtered_configs = [
188
+ triton.Config(c["config"], num_stages=1, num_warps=c["num_warps"])
189
+ for c in mm_triton_configs
190
+ if c["cond"]
191
+ ]
192
+ else:
193
+ filtered_configs = [
194
+ triton.Config(
195
+ c["config"], num_stages=c["num_stages"], num_warps=c["num_warps"]
196
+ )
197
+ for c in mm_triton_configs
198
+ if c["cond"]
199
+ ]
200
+
201
+ return filtered_configs
202
+
203
+
204
+ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
205
+ """
206
+ Computes mm(mat1, mat2) + mm(mat3, mat4)
207
+ """
208
+ m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
209
+ m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout)
210
+ # Optimization is optional, because we can always just not do the fusion
211
+ if (
212
+ m1 * n1 == 0
213
+ or m2 * n2 == 0
214
+ or not V.graph.sizevars.statically_known_list_equals(
215
+ mat1.get_size(), mat3.get_size()
216
+ )
217
+ or not V.graph.sizevars.statically_known_list_equals(
218
+ mat2.get_size(), mat4.get_size()
219
+ )
220
+ ):
221
+ # TODO(jansel): support different K values when this is fixed:
222
+ # https://github.com/openai/triton/issues/967
223
+ return lowerings[aten.add](
224
+ lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4)
225
+ )
226
+
227
+ assert layout1 == layout2
228
+ # options to tune from
229
+ choices = (
230
+ [aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)]
231
+ if use_aten_gemm_kernels()
232
+ else []
233
+ )
234
+ if use_triton_template(layout1):
235
+ for config in mm_configs():
236
+ # see https://github.com/openai/triton/issues/1298
237
+ # BLOCK_K = K causes llvm error
238
+ if V.graph.sizevars.statically_known_lt(config.kwargs["BLOCK_K"], k1):
239
+ mm_plus_mm_template.maybe_append_choice(
240
+ choices,
241
+ input_nodes=(mat1, mat2, mat3, mat4),
242
+ layout=layout1,
243
+ **mm_options(config, m1, n1, k1, layout1),
244
+ )
245
+
246
+ return autotune_select_algorithm(
247
+ "mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1
248
+ )
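When the shape checks above bail out, the lowering falls back to two separate matmuls plus an add; the fused template computes the same result in a single pass over K. A plain-PyTorch reference of the semantics (helper name illustrative):

    import torch

    def mm_plus_mm_reference(mat1, mat2, mat3, mat4):
        # the fused path additionally requires mat1/mat3 and mat2/mat4
        # to have matching shapes, but the result is the same either way
        return torch.mm(mat1, mat2) + torch.mm(mat3, mat4)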
.venv/Lib/site-packages/torch/_inductor/kernel/mm_scaled.py ADDED
@@ -0,0 +1,311 @@
1
+ import logging
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ import sympy
5
+
6
+ import torch
7
+
8
+ from .. import config as inductor_config
9
+ from ..ir import ChoiceCaller, Layout, StorageBox, TensorBox
10
+ from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering
11
+ from ..select_algorithm import (
12
+ autotune_select_algorithm,
13
+ ExternKernelChoice,
14
+ NoValidChoicesError,
15
+ realize_inputs,
16
+ TritonTemplate,
17
+ )
18
+ from ..utils import use_aten_gemm_kernels, use_triton_template
19
+ from .mm import _is_static_problem # TODO(yangsiyu) move to mm_common
20
+ from .mm_common import mm_args, mm_grid, scaled_mm_configs
21
+
22
+
23
+ log = logging.getLogger(__name__)
24
+ aten = torch.ops.aten
25
+
26
+
27
+ scaled_mm_template = TritonTemplate(
28
+ name="scaled_mm",
29
+ grid=mm_grid,
30
+ source=r"""
31
+ {{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}}
32
+ M = {{size("A", 0)}}
33
+ N = {{size("B", 1)}}
34
+ K = {{size("A", 1)}}
35
+ if M * N == 0:
36
+ # early exit due to zero-size input(s)
37
+ return
38
+ stride_am = {{stride("A", 0)}}
39
+ stride_ak = {{stride("A", 1)}}
40
+ stride_bk = {{stride("B", 0)}}
41
+ stride_bn = {{stride("B", 1)}}
42
+
43
+ # based on triton.ops.matmul
44
+ pid = tl.program_id(0)
45
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
46
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
47
+
48
+ # re-order program ID for better L2 performance
49
+ width = GROUP_M * grid_n
50
+ group_id = pid // width
51
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
52
+ pid_m = group_id * GROUP_M + (pid % group_size)
53
+ pid_n = (pid % width) // (group_size)
54
+
55
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
56
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
57
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
58
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
59
+ rk = tl.arange(0, BLOCK_K)
60
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
61
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
62
+
63
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
64
+ for k in range(K, 0, -BLOCK_K):
65
+ if EVEN_K:
66
+ a = tl.load(A)
67
+ b = tl.load(B)
68
+ else:
69
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
70
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
71
+ if B_PROLOGUE_CAST_TYPE is not None:
72
+ b = b.to(B_PROLOGUE_CAST_TYPE)
73
+ if USE_FAST_ACCUM:
74
+ acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE)
75
+ else:
76
+ acc += tl.dot(a, b, out_dtype=ACC_TYPE)
77
+ A += BLOCK_K * stride_ak
78
+ B += BLOCK_K * stride_bk
79
+
80
+ if SCALING_ROWWISE:
81
+ inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M)
82
+ inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N)
83
+ inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :]
84
+ acc *= inv_scale_row
85
+ else:
86
+ # for tensor-wise scaling, the scales are scalars
87
+ inv_a_scale = tl.load(A_inverse_scale)
88
+ inv_b_scale = tl.load(B_inverse_scale)
89
+ inv_scale = inv_a_scale * inv_b_scale
90
+ acc *= inv_scale
91
+
92
+ # rematerialize rm and rn to save registers
93
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
94
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
95
+
96
+ idx_m = rm[:, None]
97
+ idx_n = rn[None, :]
98
+ mask = (idx_m < M) & (idx_n < N)
99
+
100
+ # inductor generates a suffix
101
+ {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
102
+ """,
103
+ )
104
+
105
+
106
+ # Inductor does not allow optional tensor input arguments currently (pass None as an
107
+ # input node to template choices), but since for _scaled_mm there is only one such arg
108
+ # (bias), work around by having a second template when bias is provided.
109
+ scaled_mm_bias_template = TritonTemplate(
110
+ name="scaled_mm_bias",
111
+ grid=mm_grid,
112
+ source=r"""
113
+ {{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale", "bias_ptr")}}
114
+ M = {{size("A", 0)}}
115
+ N = {{size("B", 1)}}
116
+ K = {{size("A", 1)}}
117
+ if M * N == 0:
118
+ # early exit due to zero-size input(s)
119
+ return
120
+ stride_am = {{stride("A", 0)}}
121
+ stride_ak = {{stride("A", 1)}}
122
+ stride_bk = {{stride("B", 0)}}
123
+ stride_bn = {{stride("B", 1)}}
124
+
125
+ # based on triton.ops.matmul
126
+ pid = tl.program_id(0)
127
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
128
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
129
+
130
+ # re-order program ID for better L2 performance
131
+ width = GROUP_M * grid_n
132
+ group_id = pid // width
133
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
134
+ pid_m = group_id * GROUP_M + (pid % group_size)
135
+ pid_n = (pid % width) // (group_size)
136
+
137
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
138
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
139
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
140
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
141
+ rk = tl.arange(0, BLOCK_K)
142
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
143
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
144
+
145
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
146
+ for k in range(K, 0, -BLOCK_K):
147
+ if EVEN_K:
148
+ a = tl.load(A)
149
+ b = tl.load(B)
150
+ else:
151
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
152
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
153
+ if B_PROLOGUE_CAST_TYPE is not None:
154
+ b = b.to(B_PROLOGUE_CAST_TYPE)
155
+ if USE_FAST_ACCUM:
156
+ acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE)
157
+ else:
158
+ acc += tl.dot(a, b, out_dtype=ACC_TYPE)
159
+ A += BLOCK_K * stride_ak
160
+ B += BLOCK_K * stride_bk
161
+
162
+ if SCALING_ROWWISE:
163
+ inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M)
164
+ inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N)
165
+ inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :]
166
+ acc *= inv_scale_row
167
+ else:
168
+ # for tensor-wise scaling, the scales are scalars
169
+ inv_a_scale = tl.load(A_inverse_scale)
170
+ inv_b_scale = tl.load(B_inverse_scale)
171
+ inv_scale = inv_a_scale * inv_b_scale
172
+ acc *= inv_scale
173
+
174
+ # rematerialize rm and rn to save registers
175
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
176
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
177
+
178
+ # bias
179
+ bias = tl.load(bias_ptr + rn, mask=rn < N)
180
+ acc += bias
181
+
182
+ idx_m = rm[:, None]
183
+ idx_n = rn[None, :]
184
+ mask = (idx_m < M) & (idx_n < N)
185
+
186
+ # inductor generates a suffix
187
+ {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
188
+ """,
189
+ )
190
+
191
+
192
+ aten__fp8_mm = ExternKernelChoice(torch._scaled_mm, "at::_scaled_mm")
193
+
194
+
195
+ def are_compatible_scales(size_a: List[int], size_b: List[int]) -> bool:
196
+ # Same-sized scales are compatible
197
+ if len(size_a) == len(size_b):
198
+ return True
199
+
200
+ # Both need to be scalars or len(1) tensors
201
+ if len(size_a) <= 1 and len(size_b) <= 1:
202
+ return True
203
+
204
+ return False
205
+
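Concretely, the check accepts scale shapes of equal rank (e.g. a per-row [M] scale with a per-column [N] scale) as well as any mix of scalars and single-element tensors; only a higher-rank scale paired with a lower-rank one is rejected:

    assert are_compatible_scales([16], [32])         # both 1-D: row-wise scaling
    assert are_compatible_scales([], [1])            # scalar and single-element tensor
    assert not are_compatible_scales([16, 1], [32])  # rank mismatch beyond scalars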
206
+
207
+ def scaled_mm_options( # type: ignore[no-untyped-def]
208
+ config, # triton.Config
209
+ sym_m: sympy.core.numbers.Integer,
210
+ sym_n: sympy.core.numbers.Integer,
211
+ sym_k: sympy.core.numbers.Integer,
212
+ layout: Layout,
213
+ scale_a: StorageBox,
214
+ scale_b: StorageBox,
215
+ use_fast_accum: bool,
216
+ b_prologue_cast_type: Optional[str] = None,
217
+ ) -> Dict[str, Any]:
218
+ even_k_symbolic = (
219
+ sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
220
+ )
221
+
222
+ size_a, size_b = scale_a.get_size(), scale_b.get_size()
223
+ assert are_compatible_scales(size_a, size_b), (
224
+ "Expect scale_a and scale_b to be either both scalars (including single-element tensors) "
225
+ f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}."
226
+ )
227
+ return dict(
228
+ GROUP_M=8,
229
+ EVEN_K=even_k_symbolic,
230
+ ACC_TYPE="tl.float32",
231
+ B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
232
+ USE_FAST_ACCUM=use_fast_accum,
233
+ num_stages=config.num_stages,
234
+ num_warps=config.num_warps,
235
+ # tensor-wise scaling if scalar scales
236
+ SCALING_ROWWISE=len(scale_a.get_size()) == 2,
237
+ **config.kwargs,
238
+ )
239
+
240
+
241
+ add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides)
242
+
243
+
244
+ @register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc]
245
+ def tuned_scaled_mm(
246
+ mat_a: TensorBox,
247
+ mat_b: TensorBox,
248
+ scale_a: TensorBox,
249
+ scale_b: TensorBox,
250
+ bias: Optional[TensorBox] = None,
251
+ scale_result: Optional[TensorBox] = None,
252
+ out_dtype: Optional[torch.dtype] = None,
253
+ use_fast_accum: bool = False,
254
+ layout: Optional[Layout] = None,
255
+ ) -> TensorBox:
256
+ m, n, k, layout, mat_a, mat_b = mm_args(
257
+ mat_a, mat_b, layout=layout, out_dtype=out_dtype
258
+ )
259
+ scale_a, scale_b = realize_inputs(scale_a, scale_b)
260
+
261
+ input_nodes: Tuple[Any, ...]
262
+ # workaround for Inductor not supporting optional tensor input arguments
263
+ if bias is None:
264
+ input_nodes = (mat_a, mat_b, scale_a, scale_b)
265
+ triton_template = scaled_mm_template
266
+ else:
267
+ bias = realize_inputs(bias)
268
+ input_nodes = (mat_a, mat_b, scale_a, scale_b, bias)
269
+ triton_template = scaled_mm_bias_template
270
+
271
+ aten_choice = aten__fp8_mm.bind(
272
+ input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum
273
+ )
274
+
275
+ choices: List[ChoiceCaller] = []
276
+ if use_aten_gemm_kernels():
277
+ choices.append(aten_choice)
278
+
279
+ static_shape, is_nonzero = _is_static_problem([mat_a, mat_b], layout)
280
+ if is_nonzero and use_triton_template(layout, enable_float8=True):
281
+ for config in scaled_mm_configs(m, n, k):
282
+ if k == 16 and config.kwargs["BLOCK_M"] >= 64:
283
+ continue # Triton crashes in this case
284
+ kwargs = scaled_mm_options(
285
+ config, m, n, k, layout, scale_a, scale_b, use_fast_accum
286
+ )
287
+ # possibly appends a TritonTemplateCaller to choices
288
+ triton_template.maybe_append_choice(
289
+ choices,
290
+ input_nodes=input_nodes,
291
+ layout=layout,
292
+ **kwargs,
293
+ )
294
+
295
+ if (
296
+ len(choices) == 0
297
+ and not use_aten_gemm_kernels()
298
+ and inductor_config.autotune_fallback_to_aten
299
+ ):
300
+ log.warning("No choices for scaled_mm, using ATen backend as fallback")
301
+ return aten_choice.output_node()
302
+
303
+ try:
304
+ return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)
305
+ except NoValidChoicesError:
306
+ if not inductor_config.autotune_fallback_to_aten:
307
+ raise
308
+ log.warning(
309
+ "All choices for scaled_mm were invalid, using ATen backend as fallback"
310
+ )
311
+ return aten_choice.output_node()
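The templates above multiply the fp32 accumulator by the inverse scales: per-tensor scales are scalars, row-wise scales contribute an outer product of a per-row factor of A and a per-column factor of B, and the optional bias is added at the end. An eager emulation of that dequantization math (a sketch only; fp32 inputs stand in for the fp8 operands, and the helper name is illustrative):

    import torch

    def scaled_mm_reference(a, b, inv_scale_a, inv_scale_b, bias=None):
        acc = a.float() @ b.float()
        if inv_scale_a.numel() > 1:
            # row-wise: one factor per row of A times one factor per column of B
            acc = acc * (inv_scale_a.reshape(-1, 1) * inv_scale_b.reshape(1, -1))
        else:
            # tensor-wise: both scales are scalars
            acc = acc * (inv_scale_a * inv_scale_b)
        if bias is not None:
            acc = acc + bias  # broadcast along rows, as in the bias template
        return acc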
.venv/Lib/site-packages/torch/_inductor/kernel/unpack_mixed_mm.py ADDED
@@ -0,0 +1,87 @@
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+ from typing import List, TYPE_CHECKING
4
+
5
+ from ..select_algorithm import autotune_select_algorithm, TritonTemplate
6
+ from .mm_common import mm_args, mm_configs, mm_grid, mm_options
7
+
8
+
9
+ if TYPE_CHECKING:
10
+ from ..ir import ChoiceCaller
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
+ uint4x2_mixed_mm_template = TritonTemplate(
15
+ name="uint4x2_mixed_mm",
16
+ grid=mm_grid,
17
+ source=r"""
18
+ {{def_kernel("A", "B")}}
19
+ M = {{size("A", 0)}}
20
+ N = {{size("B", 1)}}
21
+ K = {{size("A", 1)}}
22
+ stride_am = {{stride("A", 0)}}
23
+ stride_ak = {{stride("A", 1)}}
24
+ stride_bk = {{stride("B", 0)}}
25
+ stride_bn = {{stride("B", 1)}}
26
+
27
+ # based on triton.ops.matmul
28
+ pid = tl.program_id(0)
29
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
30
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
31
+
32
+ # re-order program ID for better L2 performance
33
+ width = GROUP_M * grid_n
34
+ group_id = pid // width
35
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
36
+ pid_m = group_id * GROUP_M + (pid % group_size)
37
+ pid_n = (pid % width) // (group_size)
38
+
39
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
40
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
41
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
42
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
43
+ rk = tl.arange(0, BLOCK_K)
44
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
45
+ B = B + (rk[:, None]//2 * stride_bk + rbn[None, :] * stride_bn)
46
+ b_shifts = 4*(rk%2)
47
+ b_subs = 8*(1-(rk%2))
48
+
49
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
50
+ for k in range(K, 0, -BLOCK_K):
51
+ if EVEN_K:
52
+ a = tl.load(A)
53
+ b = tl.load(B)
54
+ else:
55
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
56
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
57
+ b = ((b >> b_shifts[:, None]) & 0xF) - 8
58
+ b = b.to(B_PROLOGUE_CAST_TYPE)
59
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
60
+ A += BLOCK_K * stride_ak
61
+ B += BLOCK_K//2 * stride_bk
62
+
63
+ # rematerialize rm and rn to save registers
64
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
65
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
66
+ idx_m = rm[:, None]
67
+ idx_n = rn[None, :]
68
+ mask = (idx_m < M) & (idx_n < N)
69
+
70
+ # inductor generates a suffix
71
+ {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
72
+ """,
73
+ )
74
+
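The B operand here packs two signed 4-bit values into each uint8 byte: even k indices live in the low nibble, odd k indices in the high nibble, and subtracting 8 maps the stored 0..15 range back to -8..7, matching the template's ((b >> b_shifts) & 0xF) - 8 line. A small standalone unpacking sketch (packing convention taken from the template; helper name illustrative):

    import torch

    def unpack_uint4x2(packed):
        # packed: uint8 tensor of shape (K // 2, N); returns int8 of shape (K, N)
        low = (packed & 0xF).to(torch.int8) - 8          # even k indices (shift 0)
        high = ((packed >> 4) & 0xF).to(torch.int8) - 8  # odd k indices (shift 4)
        out = torch.stack((low, high), dim=1)            # (K//2, 2, N)
        return out.reshape(-1, packed.shape[1])          # interleave back to (K, N)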
75
+
76
+ def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype):
77
+ m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True)
78
+ choices: List[ChoiceCaller] = []
79
+ b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
80
+ for config in mm_configs(m, n, k):
81
+ uint4x2_mixed_mm_template.maybe_append_choice(
82
+ choices,
83
+ input_nodes=(mat1, mat2),
84
+ layout=layout,
85
+ **mm_options(config, m, n, k, layout, b_prologue_cast_type),
86
+ )
87
+ return autotune_select_algorithm("uint4x2_mixed_mm", choices, [mat1, mat2], layout)