jetro30087 commited on
Commit
7ba6b67
·
1 Parent(s): be87a49

Upload 2 files

Browse files
Files changed (2) hide show
  1. debug/mod_tir_dynamic.py +570 -0
  2. debug/mod_tir_static.py +418 -0
debug/mod_tir_dynamic.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from tvm.script import ir as I
3
+ from tvm.script import tir as T
4
+
5
+ # fmt: off
6
+ # from tvm.script import ir as I
7
+ # from tvm.script import tir as T
8
+
9
+ @I.ir_module
10
+ class Module:
11
+ @T.prim_func
12
+ def NT_matmul1(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), var_NT_matmul: T.handle):
13
+ T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
14
+ n = T.int64()
15
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
16
+ NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), n, T.int64(4096)), "float16")
17
+ # with T.block("root"):
18
+ for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
19
+ with T.block("NT_matmul"):
20
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
21
+ T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k])
22
+ T.writes(NT_matmul[v_i0, v_i1, v_i2])
23
+ with T.init():
24
+ NT_matmul[v_i0, v_i1, v_i2] = T.float16(0)
25
+ NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k]
26
+
27
+ @T.prim_func
28
+ def extend_te(var_A: T.handle, var_concat_te: T.handle):
29
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
30
+ n = T.int64()
31
+ A = T.match_buffer(var_A, (T.int64(1), T.int64(1), n, n), "float16")
32
+ m = T.int64()
33
+ concat_te = T.match_buffer(var_concat_te, (T.int64(1), T.int64(1), n, m), "float16")
34
+ # with T.block("root"):
35
+ for b, _, i, j in T.grid(T.int64(1), T.int64(1), n, m):
36
+ with T.block("concat_te"):
37
+ v_b, v__, v_i, v_j = T.axis.remap("SSSS", [b, _, i, j])
38
+ T.reads(A[v_b, v__, v_i, v_j + n - m])
39
+ T.writes(concat_te[v_b, v__, v_i, v_j])
40
+ concat_te[v_b, v__, v_i, v_j] = T.if_then_else(v_j < m - n, T.float16(65504), A[v_b, v__, v_i, v_j + n - m])
41
+
42
+ @T.prim_func
43
+ def full(var_T_full: T.handle):
44
+ T.func_attr({"op_pattern": 0, "tir.noalias": T.bool(True)})
45
+ n = T.int64()
46
+ T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(1), T.int64(1), n), "float16")
47
+ # with T.block("root"):
48
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), n):
49
+ with T.block("T_full"):
50
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
51
+ T.reads()
52
+ T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3])
53
+ T_full[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(65504)
54
+
55
+ @T.prim_func
56
+ def fused_NT_matmul1_add1(p_lv41: T.handle, lv1386: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), p_lv2: T.handle, p_output0: T.handle):
57
+ T.func_attr({"tir.noalias": T.bool(True)})
58
+ n = T.int64()
59
+ lv41 = T.match_buffer(p_lv41, (T.int64(1), n, T.int64(4096)), "float16")
60
+ lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096)), "float16")
61
+ var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
62
+ # with T.block("root"):
63
+ var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
64
+ for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
65
+ with T.block("NT_matmul"):
66
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
67
+ T.reads(lv41[v_i0, v_i1, v_k], lv1386[v_i2, v_k])
68
+ T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
69
+ with T.init():
70
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
71
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv41[v_i0, v_i1, v_k] * lv1386[v_i2, v_k]
72
+ for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
73
+ with T.block("T_add"):
74
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
75
+ T.reads(lv2[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
76
+ T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
77
+ var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]
78
+
79
+ @T.prim_func
80
+ def fused_NT_matmul2_divide2_maximum1_minimum1_cast3(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle):
81
+ T.func_attr({"tir.noalias": T.bool(True)})
82
+ n = T.int64()
83
+ lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
84
+ m = T.int64()
85
+ lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), m, T.int64(128)), "float16")
86
+ lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16")
87
+ var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m))
88
+ # with T.block("root"):
89
+ var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
90
+ var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
91
+ var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
92
+ var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
93
+ for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(128)):
94
+ with T.block("NT_matmul"):
95
+ v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
96
+ T.reads(lv28[v_i0, v_i1, v_i2, v_k], lv29[v_i0, v_i1, v_i3, v_k])
97
+ T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
98
+ with T.init():
99
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
100
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[v_i0, v_i1, v_i2, v_k] * lv29[v_i0, v_i1, v_i3, v_k]
101
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
102
+ with T.block("T_divide"):
103
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
104
+ T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
105
+ T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
106
+ var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615)
107
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
108
+ with T.block("T_maximum"):
109
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
110
+ T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
111
+ T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
112
+ var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504))
113
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
114
+ with T.block("T_minimum"):
115
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
116
+ T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3])
117
+ T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
118
+ var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3])
119
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
120
+ with T.block("compute"):
121
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
122
+ T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
123
+ T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
124
+ var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
125
+
126
+ @T.prim_func
127
+ def fused_NT_matmul3_multiply1(p_lv45: T.handle, lv1400: T.Buffer((T.int64(11008), T.int64(4096)), "float16"), p_lv50: T.handle, p_output0: T.handle):
128
+ T.func_attr({"tir.noalias": T.bool(True)})
129
+ n = T.int64()
130
+ lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16")
131
+ lv50 = T.match_buffer(p_lv50, (T.int64(1), n, T.int64(11008)), "float16")
132
+ var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16")
133
+ # with T.block("root"):
134
+ var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
135
+ for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)):
136
+ with T.block("NT_matmul"):
137
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
138
+ T.reads(lv45[v_i0, v_i1, v_k], lv1400[v_i2, v_k])
139
+ T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
140
+ with T.init():
141
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
142
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * lv1400[v_i2, v_k]
143
+ for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)):
144
+ with T.block("T_multiply"):
145
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
146
+ T.reads(lv50[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
147
+ T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
148
+ var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = lv50[v_ax0, v_ax1, v_ax2] * var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]
149
+
150
+ @T.prim_func
151
+ def fused_NT_matmul3_silu1(p_lv45: T.handle, lv1393: T.Buffer((T.int64(11008), T.int64(4096)), "float16"), p_output0: T.handle):
152
+ T.func_attr({"tir.noalias": T.bool(True)})
153
+ n = T.int64()
154
+ lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16")
155
+ var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16")
156
+ # with T.block("root"):
157
+ var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
158
+ compute = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
159
+ for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)):
160
+ with T.block("NT_matmul"):
161
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
162
+ T.reads(lv45[v_i0, v_i1, v_k], lv1393[v_i2, v_k])
163
+ T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
164
+ with T.init():
165
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
166
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * lv1393[v_i2, v_k]
167
+ for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(11008)):
168
+ with T.block("compute"):
169
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
170
+ T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
171
+ T.writes(compute[v_i0, v_i1, v_i2])
172
+ compute[v_i0, v_i1, v_i2] = T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
173
+ for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)):
174
+ with T.block("T_multiply"):
175
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
176
+ T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
177
+ T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
178
+ var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]
179
+
180
+ @T.prim_func
181
+ def fused_NT_matmul4_add1(p_lv51: T.handle, lv1407: T.Buffer((T.int64(4096), T.int64(11008)), "float16"), p_lv44: T.handle, p_output0: T.handle):
182
+ T.func_attr({"tir.noalias": T.bool(True)})
183
+ n = T.int64()
184
+ lv51 = T.match_buffer(p_lv51, (T.int64(1), n, T.int64(11008)), "float16")
185
+ lv44 = T.match_buffer(p_lv44, (T.int64(1), n, T.int64(4096)), "float16")
186
+ var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
187
+ # with T.block("root"):
188
+ var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
189
+ for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)):
190
+ with T.block("NT_matmul"):
191
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
192
+ T.reads(lv51[v_i0, v_i1, v_k], lv1407[v_i2, v_k])
193
+ T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
194
+ with T.init():
195
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
196
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv51[v_i0, v_i1, v_k] * lv1407[v_i2, v_k]
197
+ for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
198
+ with T.block("T_add"):
199
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
200
+ T.reads(lv44[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
201
+ T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
202
+ var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv44[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]
203
+
204
+ @T.prim_func
205
+ def fused_NT_matmul_divide1_maximum_minimum_cast(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), p_lv1606: T.handle, p_lv1582: T.handle, p_output0: T.handle):
206
+ T.func_attr({"tir.noalias": T.bool(True)})
207
+ n = T.int64()
208
+ lv1606 = T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
209
+ lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16")
210
+ var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n))
211
+ # with T.block("root"):
212
+ var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
213
+ var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
214
+ var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
215
+ var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
216
+ for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)):
217
+ with T.block("NT_matmul"):
218
+ v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
219
+ T.reads(lv1605[v_i0, v_i1, v_i2, v_k], lv1606[v_i0, v_i1, v_i3, v_k])
220
+ T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
221
+ with T.init():
222
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
223
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1605[v_i0, v_i1, v_i2, v_k] * lv1606[v_i0, v_i1, v_i3, v_k]
224
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
225
+ with T.block("T_divide"):
226
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
227
+ T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
228
+ T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
229
+ var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615)
230
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
231
+ with T.block("T_maximum"):
232
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
233
+ T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
234
+ T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
235
+ var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504))
236
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
237
+ with T.block("T_minimum"):
238
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
239
+ T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3])
240
+ T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
241
+ var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3])
242
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
243
+ with T.block("compute"):
244
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
245
+ T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
246
+ T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
247
+ var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
248
+
249
+ @T.prim_func
250
+ def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int64):
251
+ T.func_attr({"tir.noalias": T.bool(True)})
252
+ var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), n, n), "float16")
253
+ # with T.block("root"):
254
+ var_make_diag_mask_te_intermediate = T.alloc_buffer((n, n), "float16")
255
+ for i, j in T.grid(n, n):
256
+ with T.block("make_diag_mask_te"):
257
+ v_i, v_j = T.axis.remap("SS", [i, j])
258
+ T.reads()
259
+ T.writes(var_make_diag_mask_te_intermediate[v_i, v_j])
260
+ var_make_diag_mask_te_intermediate[v_i, v_j] = T.Select(v_i < v_j, T.float16(-65504), T.float16(65504))
261
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), n, n):
262
+ with T.block("T_broadcast_to"):
263
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
264
+ T.reads(var_make_diag_mask_te_intermediate[v_ax2, v_ax3])
265
+ T.writes(var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
266
+ var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_make_diag_mask_te_intermediate[v_ax2, v_ax3]
267
+
268
+ @T.prim_func
269
+ def fused_softmax1_cast1(p_lv1613: T.handle, p_output0: T.handle):
270
+ T.func_attr({"tir.noalias": T.bool(True)})
271
+ n = T.int64()
272
+ lv1613 = T.match_buffer(p_lv1613, (T.int64(1), T.int64(32), T.int64(1), n))
273
+ var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
274
+ # with T.block("root"):
275
+ T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
276
+ T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
277
+ T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
278
+ var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
279
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
280
+ with T.block("T_softmax_maxelem"):
281
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
282
+ T.reads(lv1613[v_i0, v_i1, v_i2, v_k])
283
+ T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
284
+ with T.init():
285
+ T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38)
286
+ T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv1613[v_i0, v_i1, v_i2, v_k])
287
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
288
+ with T.block("T_softmax_exp"):
289
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
290
+ T.reads(lv1613[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])
291
+ T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
292
+ T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv1613[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])
293
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
294
+ with T.block("T_softmax_expsum"):
295
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
296
+ T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])
297
+ T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
298
+ with T.init():
299
+ T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)
300
+ T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]
301
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
302
+ with T.block("T_softmax_norm"):
303
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
304
+ T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])
305
+ T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
306
+ T.block_attr({"axis": 3})
307
+ var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]
308
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
309
+ with T.block("compute"):
310
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
311
+ T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
312
+ T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
313
+ var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
314
+
315
+ @T.prim_func
316
+ def fused_softmax2_cast4(p_lv36: T.handle, p_output0: T.handle):
317
+ T.func_attr({"tir.noalias": T.bool(True)})
318
+ n, m = T.int64(), T.int64()
319
+ lv36 = T.match_buffer(p_lv36, (T.int64(1), T.int64(32), n, m))
320
+ var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16")
321
+ # with T.block("root"):
322
+ T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n))
323
+ T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
324
+ T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n))
325
+ var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
326
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):
327
+ with T.block("T_softmax_maxelem"):
328
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
329
+ T.reads(lv36[v_i0, v_i1, v_i2, v_k])
330
+ T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
331
+ with T.init():
332
+ T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38)
333
+ T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv36[v_i0, v_i1, v_i2, v_k])
334
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
335
+ with T.block("T_softmax_exp"):
336
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
337
+ T.reads(lv36[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])
338
+ T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
339
+ T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv36[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])
340
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):
341
+ with T.block("T_softmax_expsum"):
342
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
343
+ T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])
344
+ T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
345
+ with T.init():
346
+ T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)
347
+ T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]
348
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
349
+ with T.block("T_softmax_norm"):
350
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
351
+ T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])
352
+ T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
353
+ T.block_attr({"axis": 3})
354
+ var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]
355
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
356
+ with T.block("compute"):
357
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
358
+ T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
359
+ T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
360
+ var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
361
+
362
+ @T.prim_func
363
+ def matmul10(var_A: T.handle, var_B: T.handle, var_matmul: T.handle):
364
+ T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
365
+ n, m = T.int64(), T.int64()
366
+ A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), "float16")
367
+ B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(128)), "float16")
368
+ matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
369
+ # with T.block("root"):
370
+ for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), m):
371
+ with T.block("matmul"):
372
+ v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
373
+ T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
374
+ T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
375
+ with T.init():
376
+ matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
377
+ matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]
378
+
379
+ @T.prim_func
380
+ def matmul5(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):
381
+ T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
382
+ n = T.int64()
383
+ A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
384
+ B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
385
+ # with T.block("root"):
386
+ for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n):
387
+ with T.block("matmul"):
388
+ v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
389
+ T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
390
+ T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
391
+ with T.init():
392
+ matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
393
+ matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]
394
+
395
+ @T.prim_func
396
+ def reshape3(var_A: T.handle, var_T_reshape: T.handle):
397
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
398
+ n = T.int64()
399
+ A = T.match_buffer(var_A, (n, T.int64(32), T.int64(128)), "float16")
400
+ T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
401
+ # with T.block("root"):
402
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
403
+ with T.block("T_reshape"):
404
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
405
+ T.reads(A[((v_ax3 // T.int64(128) + v_ax2) // T.int64(32) + v_ax0 * n + v_ax1) % n, (v_ax3 // T.int64(128) + v_ax2) % T.int64(32), v_ax3 % T.int64(128)])
406
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
407
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[((v_ax3 // T.int64(128) + v_ax2) // T.int64(32) + v_ax0 * n + v_ax1) % n, (v_ax3 // T.int64(128) + v_ax2) % T.int64(32), v_ax3 % T.int64(128)]
408
+
409
+ @T.prim_func
410
+ def reshape5(var_A: T.handle, var_T_reshape: T.handle):
411
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
412
+ n = T.int64()
413
+ A = T.match_buffer(var_A, (T.int64(1), n), "int32")
414
+ T_reshape = T.match_buffer(var_T_reshape, (n,), "int32")
415
+ # with T.block("root"):
416
+ for ax0 in range(n):
417
+ with T.block("T_reshape"):
418
+ v_ax0 = T.axis.spatial(n, ax0)
419
+ T.reads(A[T.int64(0), v_ax0 % n])
420
+ T.writes(T_reshape[v_ax0])
421
+ T_reshape[v_ax0] = A[T.int64(0), v_ax0 % n]
422
+
423
+ @T.prim_func
424
+ def reshape6(var_A: T.handle, var_T_reshape: T.handle):
425
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
426
+ n = T.int64()
427
+ A = T.match_buffer(var_A, (n, T.int64(4096)), "float16")
428
+ T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16")
429
+ # with T.block("root"):
430
+ for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
431
+ with T.block("T_reshape"):
432
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
433
+ T.reads(A[(v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096)])
434
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
435
+ T_reshape[v_ax0, v_ax1, v_ax2] = A[(v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096)]
436
+
437
+ @T.prim_func
438
+ def reshape7(var_A: T.handle, var_T_reshape: T.handle):
439
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
440
+ n = T.int64()
441
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
442
+ T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
443
+ # with T.block("root"):
444
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
445
+ with T.block("T_reshape"):
446
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
447
+ T.reads(A[T.int64(0), ((v_ax2 * T.int64(128) + v_ax3) // T.int64(4096) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])
448
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
449
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), ((v_ax2 * T.int64(128) + v_ax3) // T.int64(4096) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]
450
+
451
+ @T.prim_func
452
+ def reshape8(var_A: T.handle, var_T_reshape: T.handle):
453
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
454
+ n = T.int64()
455
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
456
+ T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16")
457
+ # with T.block("root"):
458
+ for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
459
+ with T.block("T_reshape"):
460
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
461
+ T.reads(A[T.int64(0), (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)])
462
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
463
+ T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)]
464
+
465
+ @T.prim_func
466
+ def rms_norm(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle):
467
+ T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
468
+ n = T.int64()
469
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
470
+ rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16")
471
+ # with T.block("root"):
472
+ Ared_temp = T.alloc_buffer((T.int64(1), n))
473
+ for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
474
+ with T.block("Ared_temp"):
475
+ v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
476
+ T.reads(A[v_bsz, v_i, v_k])
477
+ T.writes(Ared_temp[v_bsz, v_i])
478
+ with T.init():
479
+ Ared_temp[v_bsz, v_i] = T.float32(0)
480
+ Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k])
481
+ for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
482
+ with T.block("rms_norm"):
483
+ v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
484
+ T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
485
+ T.writes(rms_norm_1[v_bsz, v_i, v_k])
486
+ rms_norm_1[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / T.sqrt(Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))
487
+
488
+ @T.prim_func
489
+ def rotary_embedding(var_A: T.handle, B: T.Buffer((T.int64(2048), T.int64(128)), "float16"), C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), var_rotary: T.handle, m: T.int64):
490
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
491
+ n = T.int64()
492
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
493
+ rotary = T.match_buffer(var_rotary, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
494
+ # with T.block("root"):
495
+ for i0, i1, i2, i3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
496
+ with T.block("rotary"):
497
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
498
+ T.reads(B[m + v_i1 - n, v_i3], A[v_i0, v_i1, v_i2, v_i3 - T.int64(64):v_i3 - T.int64(64) + T.int64(129)], C[m + v_i1 - n, v_i3])
499
+ T.writes(rotary[v_i0, v_i1, v_i2, v_i3])
500
+ rotary[v_i0, v_i1, v_i2, v_i3] = B[m + v_i1 - n, v_i3] * A[v_i0, v_i1, v_i2, v_i3] + C[m + v_i1 - n, v_i3] * T.Select(T.int64(64) <= v_i3, A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)], A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1))
501
+
502
+ @T.prim_func
503
+ def slice(var_A: T.handle, slice_1: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
504
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
505
+ n = T.int64()
506
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
507
+ # with T.block("root"):
508
+ for i, j, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
509
+ with T.block("slice"):
510
+ v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k])
511
+ T.reads(A[v_i, n - T.int64(1), v_k])
512
+ T.writes(slice_1[v_i, v_j, v_k])
513
+ slice_1[v_i, v_j, v_k] = A[v_i, n - T.int64(1), v_k]
514
+
515
+ @T.prim_func
516
+ def squeeze1(var_A: T.handle, var_T_squeeze: T.handle):
517
+ T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
518
+ n = T.int64()
519
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
520
+ T_squeeze = T.match_buffer(var_T_squeeze, (n, T.int64(32), T.int64(128)), "float16")
521
+ # with T.block("root"):
522
+ for ax0, ax1, ax2 in T.grid(n, T.int64(32), T.int64(128)):
523
+ with T.block("T_squeeze"):
524
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
525
+ T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2])
526
+ T.writes(T_squeeze[v_ax0, v_ax1, v_ax2])
527
+ T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2]
528
+
529
+ @T.prim_func
530
+ def take_decode1(A: T.Buffer((T.int64(32000), T.int64(824)), "uint16"), B: T.Buffer((T.int64(32000), T.int64(103)), "float16"), var_C: T.handle, var_take_decode: T.handle):
531
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
532
+ n = T.int64()
533
+ C = T.match_buffer(var_C, (n,), "int32")
534
+ take_decode = T.match_buffer(var_take_decode, (n, T.int64(4096)), "float16")
535
+ # with T.block("root"):
536
+ for i, j in T.grid(n, T.int64(4096)):
537
+ with T.block("take_decode"):
538
+ v_i, v_j = T.axis.remap("SS", [i, j])
539
+ T.reads(A[C[v_i], v_j // T.int64(5)], C[v_i], B[C[v_i], v_j // T.int64(40)])
540
+ T.writes(take_decode[v_i, v_j])
541
+ take_decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[C[v_i], v_j // T.int64(5)]), T.Cast("uint32", v_j % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[C[v_i], v_j // T.int64(40)]
542
+
543
+ @T.prim_func
544
+ def transpose4(var_A: T.handle, var_T_transpose: T.handle):
545
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
546
+ n = T.int64()
547
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
548
+ T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
549
+ # with T.block("root"):
550
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, T.int64(128)):
551
+ with T.block("T_transpose"):
552
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
553
+ T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
554
+ T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
555
+ T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
556
+
557
+ @T.prim_func
558
+ def transpose7(var_A: T.handle, var_T_transpose: T.handle):
559
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
560
+ n = T.int64()
561
+ A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
562
+ T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
563
+ # with T.block("root"):
564
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
565
+ with T.block("T_transpose"):
566
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
567
+ T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
568
+ T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
569
+ T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
570
+ # fmt: on
debug/mod_tir_static.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from tvm.script import ir as I
3
+ from tvm.script import tir as T
4
+
5
+ # fmt: off
6
+ # from tvm.script import ir as I
7
+ # from tvm.script import tir as T
8
+
9
+ @I.ir_module
10
+ class Module:
11
+ @T.prim_func
12
+ def decode5(A: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), B: T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")):
13
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
14
+ # with T.block("root"):
15
+ decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
16
+ for i, j in T.grid(T.int64(4096), T.int64(4096)):
17
+ with T.block("decode"):
18
+ v_i, v_j = T.axis.remap("SS", [i, j])
19
+ T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j])
20
+ T.writes(decode[v_i, v_j])
21
+ decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j]
22
+ for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
23
+ with T.block("T_transpose"):
24
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
25
+ T.reads(decode[v_ax1, v_ax0])
26
+ T.writes(T_transpose[v_ax0, v_ax1])
27
+ T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
28
+
29
+ @T.prim_func
30
+ def decode6(A: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), B: T.Buffer((T.int64(103), T.int64(11008)), "float16"), T_transpose: T.Buffer((T.int64(11008), T.int64(4096)), "float16")):
31
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
32
+ # with T.block("root"):
33
+ decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
34
+ for i, j in T.grid(T.int64(4096), T.int64(11008)):
35
+ with T.block("decode"):
36
+ v_i, v_j = T.axis.remap("SS", [i, j])
37
+ T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j])
38
+ T.writes(decode[v_i, v_j])
39
+ decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j]
40
+ for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)):
41
+ with T.block("T_transpose"):
42
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
43
+ T.reads(decode[v_ax1, v_ax0])
44
+ T.writes(T_transpose[v_ax0, v_ax1])
45
+ T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
46
+
47
+ @T.prim_func
48
+ def decode7(A: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), B: T.Buffer((T.int64(276), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(11008)), "float16")):
49
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
50
+ # with T.block("root"):
51
+ decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
52
+ for i, j in T.grid(T.int64(11008), T.int64(4096)):
53
+ with T.block("decode"):
54
+ v_i, v_j = T.axis.remap("SS", [i, j])
55
+ T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j])
56
+ T.writes(decode[v_i, v_j])
57
+ decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j]
58
+ for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)):
59
+ with T.block("T_transpose"):
60
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
61
+ T.reads(decode[v_ax1, v_ax0])
62
+ T.writes(T_transpose[v_ax0, v_ax1])
63
+ T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
64
+
65
+ @T.prim_func
66
+ def divide(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), B: T.Buffer((), "float32"), T_divide: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):
67
+ T.func_attr({"op_pattern": 0, "tir.noalias": T.bool(True)})
68
+ # with T.block("root"):
69
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
70
+ with T.block("T_divide"):
71
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
72
+ T.reads(A[v_ax0, v_ax1, v_ax2], B[()])
73
+ T.writes(T_divide[v_ax0, v_ax1, v_ax2])
74
+ T_divide[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2] / B[()]
75
+
76
+ @T.prim_func
77
+ def fused_decode1_fused_matmul4_add(lv26: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv27: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1581: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
78
+ T.func_attr({"tir.noalias": T.bool(True)})
79
+ # with T.block("root"):
80
+ var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
81
+ var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")
82
+ for i, j in T.grid(T.int64(4096), T.int64(4096)):
83
+ with T.block("decode"):
84
+ v_i, v_j = T.axis.remap("SS", [i, j])
85
+ T.reads(lv26[v_i // T.int64(5), v_j], lv27[v_i // T.int64(40), v_j])
86
+ T.writes(var_decode_intermediate[v_i, v_j])
87
+ var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv26[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv27[v_i // T.int64(40), v_j]
88
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):
89
+ with T.block("matmul"):
90
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
91
+ T.reads(lv3[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
92
+ T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
93
+ with T.init():
94
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
95
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv3[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
96
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
97
+ with T.block("T_add"):
98
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
99
+ T.reads(lv1581[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
100
+ T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
101
+ p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv1581[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
102
+
103
+ @T.prim_func
104
+ def fused_decode1_matmul4(lv8: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv9: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv1583: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
105
+ T.func_attr({"tir.noalias": T.bool(True)})
106
+ # with T.block("root"):
107
+ var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
108
+ for i, j in T.grid(T.int64(4096), T.int64(4096)):
109
+ with T.block("decode"):
110
+ v_i, v_j = T.axis.remap("SS", [i, j])
111
+ T.reads(lv8[v_i // T.int64(5), v_j], lv9[v_i // T.int64(40), v_j])
112
+ T.writes(var_decode_intermediate[v_i, v_j])
113
+ var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv8[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv9[v_i // T.int64(40), v_j]
114
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):
115
+ with T.block("matmul"):
116
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
117
+ T.reads(lv1583[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
118
+ T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
119
+ with T.init():
120
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
121
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1583[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
122
+
123
+ @T.prim_func
124
+ def fused_decode2_fused_matmul6_multiply(lv38: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv39: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")):
125
+ T.func_attr({"tir.noalias": T.bool(True)})
126
+ # with T.block("root"):
127
+ var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
128
+ var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
129
+ for i, j in T.grid(T.int64(4096), T.int64(11008)):
130
+ with T.block("decode"):
131
+ v_i, v_j = T.axis.remap("SS", [i, j])
132
+ T.reads(lv38[v_i // T.int64(5), v_j], lv39[v_i // T.int64(40), v_j])
133
+ T.writes(var_decode_intermediate[v_i, v_j])
134
+ var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv38[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv39[v_i // T.int64(40), v_j]
135
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):
136
+ with T.block("matmul"):
137
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
138
+ T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
139
+ T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
140
+ with T.init():
141
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
142
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
143
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
144
+ with T.block("T_multiply"):
145
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
146
+ T.reads(lv4[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
147
+ T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
148
+ p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv4[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
149
+
150
+ @T.prim_func
151
+ def fused_decode2_fused_matmul6_silu(lv32: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv33: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")):
152
+ T.func_attr({"tir.noalias": T.bool(True)})
153
+ # with T.block("root"):
154
+ var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
155
+ var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
156
+ compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
157
+ for i, j in T.grid(T.int64(4096), T.int64(11008)):
158
+ with T.block("decode"):
159
+ v_i, v_j = T.axis.remap("SS", [i, j])
160
+ T.reads(lv32[v_i // T.int64(5), v_j], lv33[v_i // T.int64(40), v_j])
161
+ T.writes(var_decode_intermediate[v_i, v_j])
162
+ var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv32[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv33[v_i // T.int64(40), v_j]
163
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):
164
+ with T.block("matmul"):
165
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
166
+ T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
167
+ T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
168
+ with T.init():
169
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
170
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
171
+ for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
172
+ with T.block("compute"):
173
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
174
+ T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])
175
+ T.writes(compute[v_i0, v_i1, v_i2])
176
+ compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2])
177
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
178
+ with T.block("T_multiply"):
179
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
180
+ T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
181
+ T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
182
+ p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]
183
+
184
+ @T.prim_func
185
+ def fused_decode3_fused_matmul7_add(lv44: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), lv45: T.Buffer((T.int64(276), T.int64(4096)), "float16"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
186
+ T.func_attr({"tir.noalias": T.bool(True)})
187
+ # with T.block("root"):
188
+ var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
189
+ var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")
190
+ for i, j in T.grid(T.int64(11008), T.int64(4096)):
191
+ with T.block("decode"):
192
+ v_i, v_j = T.axis.remap("SS", [i, j])
193
+ T.reads(lv44[v_i // T.int64(5), v_j], lv45[v_i // T.int64(40), v_j])
194
+ T.writes(var_decode_intermediate[v_i, v_j])
195
+ var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv44[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv45[v_i // T.int64(40), v_j]
196
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)):
197
+ with T.block("matmul"):
198
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
199
+ T.reads(lv6[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
200
+ T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
201
+ with T.init():
202
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
203
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv6[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
204
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
205
+ with T.block("T_add"):
206
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
207
+ T.reads(lv4[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
208
+ T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
209
+ p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv4[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
210
+
211
+ @T.prim_func
212
+ def fused_decode4_fused_matmul8_cast2(lv2931: T.Buffer((T.int64(824), T.int64(32000)), "uint16"), lv2932: T.Buffer((T.int64(103), T.int64(32000)), "float16"), lv1575: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):
213
+ T.func_attr({"tir.noalias": T.bool(True)})
214
+ # with T.block("root"):
215
+ var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16")
216
+ var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16")
217
+ for i, j in T.grid(T.int64(4096), T.int64(32000)):
218
+ with T.block("decode"):
219
+ v_i, v_j = T.axis.remap("SS", [i, j])
220
+ T.reads(lv2931[v_i // T.int64(5), v_j], lv2932[v_i // T.int64(40), v_j])
221
+ T.writes(var_decode_intermediate[v_i, v_j])
222
+ var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv2931[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv2932[v_i // T.int64(40), v_j]
223
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)):
224
+ with T.block("matmul"):
225
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
226
+ T.reads(lv1575[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
227
+ T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
228
+ with T.init():
229
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
230
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1575[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
231
+ for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
232
+ with T.block("compute"):
233
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
234
+ T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])
235
+ T.writes(p_output0_intermediate[v_i0, v_i1, v_i2])
236
+ p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2])
237
+
238
+ @T.prim_func
239
+ def fused_reshape2_squeeze(lv1591: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_T_squeeze_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(128)), "float16")):
240
+ T.func_attr({"tir.noalias": T.bool(True)})
241
+ # with T.block("root"):
242
+ var_T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")
243
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
244
+ with T.block("T_reshape"):
245
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
246
+ T.reads(lv1591[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])
247
+ T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
248
+ var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv1591[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]
249
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(128)):
250
+ with T.block("T_squeeze"):
251
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
252
+ T.reads(var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2])
253
+ T.writes(var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2])
254
+ var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2] = var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2]
255
+
256
+ @T.prim_func
257
+ def fused_transpose5_reshape4(lv1616: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), var_T_reshape_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
258
+ T.func_attr({"tir.noalias": T.bool(True)})
259
+ # with T.block("root"):
260
+ var_T_transpose_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")
261
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
262
+ with T.block("T_transpose"):
263
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
264
+ T.reads(lv1616[v_ax0, v_ax2, v_ax1, v_ax3])
265
+ T.writes(var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
266
+ var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv1616[v_ax0, v_ax2, v_ax1, v_ax3]
267
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
268
+ with T.block("T_reshape"):
269
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
270
+ T.reads(var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)])
271
+ T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2])
272
+ var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2] = var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)]
273
+
274
+ @T.prim_func
275
+ def reshape(A: T.Buffer((T.int64(1), T.int64(1)), "int32"), T_reshape: T.Buffer((T.int64(1),), "int32")):
276
+ T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
277
+ # with T.block("root"):
278
+ for ax0 in range(T.int64(1)):
279
+ with T.block("T_reshape"):
280
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
281
+ T.reads(A[T.int64(0), T.int64(0)])
282
+ T.writes(T_reshape[v_ax0])
283
+ T_reshape[v_ax0] = A[T.int64(0), T.int64(0)]
284
+
285
+ @T.prim_func
286
+ def reshape1(A: T.Buffer((T.int64(1), T.int64(4096)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
287
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
288
+ # with T.block("root"):
289
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
290
+ with T.block("T_reshape"):
291
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
292
+ T.reads(A[T.int64(0), v_ax2 % T.int64(4096)])
293
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
294
+ T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax2 % T.int64(4096)]
295
+
296
+ @T.prim_func
297
+ def reshape2(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")):
298
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
299
+ # with T.block("root"):
300
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
301
+ with T.block("T_reshape"):
302
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
303
+ T.reads(A[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])
304
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
305
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]
306
+
307
+ @T.prim_func
308
+ def rms_norm1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), B: T.Buffer((T.int64(4096),), "float16"), rms_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
309
+ T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
310
+ # with T.block("root"):
311
+ Ared_temp = T.alloc_buffer((T.int64(1), T.int64(1)))
312
+ for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
313
+ with T.block("Ared_temp"):
314
+ v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
315
+ T.reads(A[v_bsz, v_i, v_k])
316
+ T.writes(Ared_temp[v_bsz, v_i])
317
+ with T.init():
318
+ Ared_temp[v_bsz, v_i] = T.float32(0)
319
+ Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k])
320
+ for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
321
+ with T.block("rms_norm"):
322
+ v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
323
+ T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
324
+ T.writes(rms_norm[v_bsz, v_i, v_k])
325
+ rms_norm[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / T.sqrt(Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))
326
+
327
+ @T.prim_func
328
+ def rotary_embedding1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), B: T.Buffer((T.int64(2048), T.int64(128)), "float16"), C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), rotary: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), n: T.int64):
329
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
330
+ # with T.block("root"):
331
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
332
+ with T.block("rotary"):
333
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
334
+ T.reads(B[n + v_i1 - T.int64(1), v_i3], A[v_i0, v_i1, v_i2, v_i3 - T.int64(64):v_i3 - T.int64(64) + T.int64(129)], C[n + v_i1 - T.int64(1), v_i3])
335
+ T.writes(rotary[v_i0, v_i1, v_i2, v_i3])
336
+ rotary[v_i0, v_i1, v_i2, v_i3] = B[n + v_i1 - T.int64(1), v_i3] * A[v_i0, v_i1, v_i2, v_i3] + C[n + v_i1 - T.int64(1), v_i3] * T.Select(T.int64(64) <= v_i3, A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)], A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1))
337
+
338
+ @T.prim_func
339
+ def slice1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), slice: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
340
+ T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
341
+ # with T.block("root"):
342
+ for i, j, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
343
+ with T.block("slice"):
344
+ v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k])
345
+ T.reads(A[v_i, T.int64(0), v_k])
346
+ T.writes(slice[v_i, v_j, v_k])
347
+ slice[v_i, v_j, v_k] = A[v_i, T.int64(0), v_k]
348
+
349
+ @T.prim_func
350
+ def softmax(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):
351
+ T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
352
+ # with T.block("root"):
353
+ T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(1)))
354
+ T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)))
355
+ T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(1)))
356
+ for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
357
+ with T.block("T_softmax_maxelem"):
358
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
359
+ T.reads(A[v_i0, v_i1, v_k])
360
+ T.writes(T_softmax_maxelem[v_i0, v_i1])
361
+ with T.init():
362
+ T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e+38)
363
+ T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], A[v_i0, v_i1, v_k])
364
+ for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
365
+ with T.block("T_softmax_exp"):
366
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
367
+ T.reads(A[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1])
368
+ T.writes(T_softmax_exp[v_i0, v_i1, v_i2])
369
+ T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(A[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1])
370
+ for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
371
+ with T.block("T_softmax_expsum"):
372
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
373
+ T.reads(T_softmax_exp[v_i0, v_i1, v_k])
374
+ T.writes(T_softmax_expsum[v_i0, v_i1])
375
+ with T.init():
376
+ T_softmax_expsum[v_i0, v_i1] = T.float32(0)
377
+ T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k]
378
+ for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
379
+ with T.block("T_softmax_norm"):
380
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
381
+ T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1])
382
+ T.writes(T_softmax_norm[v_i0, v_i1, v_i2])
383
+ T.block_attr({"axis": 2})
384
+ T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1]
385
+
386
+ @T.prim_func
387
+ def squeeze(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), T_squeeze: T.Buffer((T.int64(1), T.int64(32), T.int64(128)), "float16")):
388
+ T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
389
+ # with T.block("root"):
390
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(128)):
391
+ with T.block("T_squeeze"):
392
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
393
+ T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2])
394
+ T.writes(T_squeeze[v_ax0, v_ax1, v_ax2])
395
+ T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2]
396
+
397
+ @T.prim_func
398
+ def take_decode(A: T.Buffer((T.int64(32000), T.int64(824)), "uint16"), B: T.Buffer((T.int64(32000), T.int64(103)), "float16"), C: T.Buffer((T.int64(1),), "int32"), take_decode_1: T.Buffer((T.int64(1), T.int64(4096)), "float16")):
399
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
400
+ # with T.block("root"):
401
+ for i, j in T.grid(T.int64(1), T.int64(4096)):
402
+ with T.block("take_decode"):
403
+ v_i, v_j = T.axis.remap("SS", [i, j])
404
+ T.reads(A[C[v_i], v_j // T.int64(5)], C[v_i], B[C[v_i], v_j // T.int64(40)])
405
+ T.writes(take_decode_1[v_i, v_j])
406
+ take_decode_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[C[v_i], v_j // T.int64(5)]), T.Cast("uint32", v_j % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[C[v_i], v_j // T.int64(40)]
407
+
408
+ @T.prim_func
409
+ def transpose3(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):
410
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
411
+ # with T.block("root"):
412
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128)):
413
+ with T.block("T_transpose"):
414
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
415
+ T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
416
+ T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
417
+ T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
418
+ # fmt: on