Commit 7ba6b67 · Parent: be87a49

Upload 2 files

- debug/mod_tir_dynamic.py (+570, -0)
- debug/mod_tir_static.py (+418, -0)
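A minimal sketch (not part of the commit) of how dumps like these can be loaded and inspected; it assumes the repository root is on PYTHONPATH and a TVM build with TVMScript support, and the function name is taken from the first file below:

    # Hypothetical usage, not included in this commit.
    from debug.mod_tir_dynamic import Module  # the @I.ir_module class below

    print(Module.script())            # pretty-print the IRModule as TVMScript
    nt_matmul = Module["NT_matmul1"]  # fetch one PrimFunc by its global symbol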
debug/mod_tir_dynamic.py
ADDED
@@ -0,0 +1,570 @@

from tvm.script import ir as I
from tvm.script import tir as T

# fmt: off
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def NT_matmul1(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), var_NT_matmul: T.handle):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
        NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k]

    @T.prim_func
    def extend_te(var_A: T.handle, var_concat_te: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), T.int64(1), n, n), "float16")
        m = T.int64()
        concat_te = T.match_buffer(var_concat_te, (T.int64(1), T.int64(1), n, m), "float16")
        # with T.block("root"):
        for b, _, i, j in T.grid(T.int64(1), T.int64(1), n, m):
            with T.block("concat_te"):
                v_b, v__, v_i, v_j = T.axis.remap("SSSS", [b, _, i, j])
                T.reads(A[v_b, v__, v_i, v_j + n - m])
                T.writes(concat_te[v_b, v__, v_i, v_j])
                concat_te[v_b, v__, v_i, v_j] = T.if_then_else(v_j < m - n, T.float16(65504), A[v_b, v__, v_i, v_j + n - m])

    @T.prim_func
    def full(var_T_full: T.handle):
        T.func_attr({"op_pattern": 0, "tir.noalias": T.bool(True)})
        n = T.int64()
        T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(1), T.int64(1), n), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), n):
            with T.block("T_full"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3])
                T_full[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(65504)

    @T.prim_func
    def fused_NT_matmul1_add1(p_lv41: T.handle, lv1386: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), p_lv2: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv41 = T.match_buffer(p_lv41, (T.int64(1), n, T.int64(4096)), "float16")
        lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096)), "float16")
        var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv41[v_i0, v_i1, v_k], lv1386[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv41[v_i0, v_i1, v_k] * lv1386[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv2[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def fused_NT_matmul2_divide2_maximum1_minimum1_cast3(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        m = T.int64()
        lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), m, T.int64(128)), "float16")
        lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16")
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m))
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(128)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(lv28[v_i0, v_i1, v_i2, v_k], lv29[v_i0, v_i1, v_i3, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[v_i0, v_i1, v_i2, v_k] * lv29[v_i0, v_i1, v_i3, v_k]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615)
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_maximum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_minimum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3])
                T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])

    @T.prim_func
    def fused_NT_matmul3_multiply1(p_lv45: T.handle, lv1400: T.Buffer((T.int64(11008), T.int64(4096)), "float16"), p_lv50: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16")
        lv50 = T.match_buffer(p_lv50, (T.int64(1), n, T.int64(11008)), "float16")
        var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv45[v_i0, v_i1, v_k], lv1400[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * lv1400[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv50[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = lv50[v_ax0, v_ax1, v_ax2] * var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def fused_NT_matmul3_silu1(p_lv45: T.handle, lv1393: T.Buffer((T.int64(11008), T.int64(4096)), "float16"), p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16")
        var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        compute = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv45[v_i0, v_i1, v_k], lv1393[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * lv1393[v_i2, v_k]
        for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("compute"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute[v_i0, v_i1, v_i2])
                compute[v_i0, v_i1, v_i2] = T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
                T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def fused_NT_matmul4_add1(p_lv51: T.handle, lv1407: T.Buffer((T.int64(4096), T.int64(11008)), "float16"), p_lv44: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv51 = T.match_buffer(p_lv51, (T.int64(1), n, T.int64(11008)), "float16")
        lv44 = T.match_buffer(p_lv44, (T.int64(1), n, T.int64(4096)), "float16")
        var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv51[v_i0, v_i1, v_k], lv1407[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv51[v_i0, v_i1, v_k] * lv1407[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv44[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv44[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def fused_NT_matmul_divide1_maximum_minimum_cast(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), p_lv1606: T.handle, p_lv1582: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv1606 = T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16")
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n))
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(lv1605[v_i0, v_i1, v_i2, v_k], lv1606[v_i0, v_i1, v_i3, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1605[v_i0, v_i1, v_i2, v_k] * lv1606[v_i0, v_i1, v_i3, v_k]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615)
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_maximum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_minimum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3])
                T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])

    @T.prim_func
    def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int64):
        T.func_attr({"tir.noalias": T.bool(True)})
        var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), n, n), "float16")
        # with T.block("root"):
        var_make_diag_mask_te_intermediate = T.alloc_buffer((n, n), "float16")
        for i, j in T.grid(n, n):
            with T.block("make_diag_mask_te"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads()
                T.writes(var_make_diag_mask_te_intermediate[v_i, v_j])
                var_make_diag_mask_te_intermediate[v_i, v_j] = T.Select(v_i < v_j, T.float16(-65504), T.float16(65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), n, n):
            with T.block("T_broadcast_to"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_make_diag_mask_te_intermediate[v_ax2, v_ax3])
                T.writes(var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_make_diag_mask_te_intermediate[v_ax2, v_ax3]

    @T.prim_func
    def fused_softmax1_cast1(p_lv1613: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv1613 = T.match_buffer(p_lv1613, (T.int64(1), T.int64(32), T.int64(1), n))
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
        T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
        T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
        var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_softmax_maxelem"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv1613[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
                with T.init():
                    T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38)
                T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv1613[v_i0, v_i1, v_i2, v_k])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_softmax_exp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(lv1613[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])
                T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
                T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv1613[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_softmax_expsum"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
                with T.init():
                    T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)
                T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_softmax_norm"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])
                T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.block_attr({"axis": 3})
                var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])

    @T.prim_func
    def fused_softmax2_cast4(p_lv36: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n, m = T.int64(), T.int64()
        lv36 = T.match_buffer(p_lv36, (T.int64(1), T.int64(32), n, m))
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16")
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n))
        T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
        T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n))
        var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_maxelem"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv36[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
                with T.init():
                    T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38)
                T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv36[v_i0, v_i1, v_i2, v_k])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_exp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(lv36[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])
                T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
                T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv36[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_expsum"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
                with T.init():
                    T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)
                T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_norm"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])
                T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.block_attr({"axis": 3})
                var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])

    @T.prim_func
    def matmul10(var_A: T.handle, var_B: T.handle, var_matmul: T.handle):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        n, m = T.int64(), T.int64()
        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), "float16")
        B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(128)), "float16")
        matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        # with T.block("root"):
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), m):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
                T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]

    @T.prim_func
    def matmul5(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
        B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        # with T.block("root"):
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
                T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]

    @T.prim_func
    def reshape3(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(32), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[((v_ax3 // T.int64(128) + v_ax2) // T.int64(32) + v_ax0 * n + v_ax1) % n, (v_ax3 // T.int64(128) + v_ax2) % T.int64(32), v_ax3 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[((v_ax3 // T.int64(128) + v_ax2) // T.int64(32) + v_ax0 * n + v_ax1) % n, (v_ax3 // T.int64(128) + v_ax2) % T.int64(32), v_ax3 % T.int64(128)]

    @T.prim_func
    def reshape5(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n), "int32")
        T_reshape = T.match_buffer(var_T_reshape, (n,), "int32")
        # with T.block("root"):
        for ax0 in range(n):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(n, ax0)
                T.reads(A[T.int64(0), v_ax0 % n])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = A[T.int64(0), v_ax0 % n]

    @T.prim_func
    def reshape6(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(4096)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[(v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = A[(v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096)]

    @T.prim_func
    def reshape7(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[T.int64(0), ((v_ax2 * T.int64(128) + v_ax3) // T.int64(4096) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), ((v_ax2 * T.int64(128) + v_ax3) // T.int64(4096) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]

    @T.prim_func
    def reshape8(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[T.int64(0), (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)]

    @T.prim_func
    def rms_norm(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
        rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        Ared_temp = T.alloc_buffer((T.int64(1), n))
        for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("Ared_temp"):
                v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
                T.reads(A[v_bsz, v_i, v_k])
                T.writes(Ared_temp[v_bsz, v_i])
                with T.init():
                    Ared_temp[v_bsz, v_i] = T.float32(0)
                Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k])
        for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("rms_norm"):
                v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
                T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
                T.writes(rms_norm_1[v_bsz, v_i, v_k])
                rms_norm_1[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / T.sqrt(Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))

    @T.prim_func
    def rotary_embedding(var_A: T.handle, B: T.Buffer((T.int64(2048), T.int64(128)), "float16"), C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), var_rotary: T.handle, m: T.int64):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        rotary = T.match_buffer(var_rotary, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for i0, i1, i2, i3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
            with T.block("rotary"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(B[m + v_i1 - n, v_i3], A[v_i0, v_i1, v_i2, v_i3 - T.int64(64):v_i3 - T.int64(64) + T.int64(129)], C[m + v_i1 - n, v_i3])
                T.writes(rotary[v_i0, v_i1, v_i2, v_i3])
                rotary[v_i0, v_i1, v_i2, v_i3] = B[m + v_i1 - n, v_i3] * A[v_i0, v_i1, v_i2, v_i3] + C[m + v_i1 - n, v_i3] * T.Select(T.int64(64) <= v_i3, A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)], A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1))

    @T.prim_func
    def slice(var_A: T.handle, slice_1: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        for i, j, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
            with T.block("slice"):
                v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k])
                T.reads(A[v_i, n - T.int64(1), v_k])
                T.writes(slice_1[v_i, v_j, v_k])
                slice_1[v_i, v_j, v_k] = A[v_i, n - T.int64(1), v_k]

    @T.prim_func
    def squeeze1(var_A: T.handle, var_T_squeeze: T.handle):
        T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        T_squeeze = T.match_buffer(var_T_squeeze, (n, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(n, T.int64(32), T.int64(128)):
            with T.block("T_squeeze"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2])
                T.writes(T_squeeze[v_ax0, v_ax1, v_ax2])
                T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def take_decode1(A: T.Buffer((T.int64(32000), T.int64(824)), "uint16"), B: T.Buffer((T.int64(32000), T.int64(103)), "float16"), var_C: T.handle, var_take_decode: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        C = T.match_buffer(var_C, (n,), "int32")
        take_decode = T.match_buffer(var_take_decode, (n, T.int64(4096)), "float16")
        # with T.block("root"):
        for i, j in T.grid(n, T.int64(4096)):
            with T.block("take_decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(A[C[v_i], v_j // T.int64(5)], C[v_i], B[C[v_i], v_j // T.int64(40)])
                T.writes(take_decode[v_i, v_j])
                take_decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[C[v_i], v_j // T.int64(5)]), T.Cast("uint32", v_j % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[C[v_i], v_j // T.int64(40)]

    @T.prim_func
    def transpose4(var_A: T.handle, var_T_transpose: T.handle):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]

    @T.prim_func
    def transpose7(var_A: T.handle, var_T_transpose: T.handle):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
# fmt: on
debug/mod_tir_static.py
ADDED
@@ -0,0 +1,418 @@
1 |
+
|
2 |
+
from tvm.script import ir as I
|
3 |
+
from tvm.script import tir as T
|
4 |
+
|
5 |
+
# fmt: off
|
6 |
+
# from tvm.script import ir as I
|
7 |
+
# from tvm.script import tir as T
|
8 |
+
|
9 |
+
@I.ir_module
|
10 |
+
class Module:
|
11 |
+
@T.prim_func
|
12 |
+
def decode5(A: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), B: T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")):
|
13 |
+
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
|
14 |
+
# with T.block("root"):
|
15 |
+
decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
|
16 |
+
for i, j in T.grid(T.int64(4096), T.int64(4096)):
|
17 |
+
with T.block("decode"):
|
18 |
+
v_i, v_j = T.axis.remap("SS", [i, j])
|
19 |
+
T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j])
|
20 |
+
T.writes(decode[v_i, v_j])
|
21 |
+
decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j]
|
22 |
+
for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
|
23 |
+
with T.block("T_transpose"):
|
24 |
+
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
|
25 |
+
T.reads(decode[v_ax1, v_ax0])
|
26 |
+
T.writes(T_transpose[v_ax0, v_ax1])
|
27 |
+
T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
|
28 |
+
|
29 |
+
@T.prim_func
|
30 |
+
def decode6(A: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), B: T.Buffer((T.int64(103), T.int64(11008)), "float16"), T_transpose: T.Buffer((T.int64(11008), T.int64(4096)), "float16")):
|
31 |
+
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
|
32 |
+
# with T.block("root"):
|
33 |
+
decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
|
34 |
+
for i, j in T.grid(T.int64(4096), T.int64(11008)):
|
35 |
+
with T.block("decode"):
|
36 |
+
v_i, v_j = T.axis.remap("SS", [i, j])
|
37 |
+
T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j])
|
38 |
+
T.writes(decode[v_i, v_j])
|
39 |
+
decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j]
|
40 |
+
for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)):
|
41 |
+
with T.block("T_transpose"):
|
42 |
+
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
|
43 |
+
T.reads(decode[v_ax1, v_ax0])
|
44 |
+
T.writes(T_transpose[v_ax0, v_ax1])
|
45 |
+
T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
|
46 |
+
|
47 |
+
@T.prim_func
|
48 |
+
def decode7(A: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), B: T.Buffer((T.int64(276), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(11008)), "float16")):
|
49 |
+
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
|
50 |
+
# with T.block("root"):
|
51 |
+
decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
|
52 |
+
for i, j in T.grid(T.int64(11008), T.int64(4096)):
|
53 |
+
with T.block("decode"):
|
54 |
+
v_i, v_j = T.axis.remap("SS", [i, j])
|
55 |
+
T.reads(A[v_i // T.int64(5), v_j], B[v_i // T.int64(40), v_j])
|
56 |
+
T.writes(decode[v_i, v_j])
|
57 |
+
decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // T.int64(40), v_j]
|
58 |
+
for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)):
|
59 |
+
with T.block("T_transpose"):
|
60 |
+
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
|
61 |
+
T.reads(decode[v_ax1, v_ax0])
|
62 |
+
T.writes(T_transpose[v_ax0, v_ax1])
|
63 |
+
T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
|
64 |
+
|
65 |
+
@T.prim_func
|
66 |
+
def divide(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), B: T.Buffer((), "float32"), T_divide: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):
|
67 |
+
T.func_attr({"op_pattern": 0, "tir.noalias": T.bool(True)})
|
68 |
+
# with T.block("root"):
|
69 |
+
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
|
70 |
+
with T.block("T_divide"):
|
71 |
+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
|
72 |
+
T.reads(A[v_ax0, v_ax1, v_ax2], B[()])
|
73 |
+
T.writes(T_divide[v_ax0, v_ax1, v_ax2])
|
74 |
+
T_divide[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2] / B[()]
|
75 |
+
|
76 |
+
@T.prim_func
|
77 |
+
def fused_decode1_fused_matmul4_add(lv26: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv27: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1581: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
|
78 |
+
T.func_attr({"tir.noalias": T.bool(True)})
|
79 |
+
# with T.block("root"):
|
80 |
+
var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
|
81 |
+
var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")
|
82 |
+
for i, j in T.grid(T.int64(4096), T.int64(4096)):
|
83 |
+
with T.block("decode"):
|
84 |
+
v_i, v_j = T.axis.remap("SS", [i, j])
|
85 |
+
T.reads(lv26[v_i // T.int64(5), v_j], lv27[v_i // T.int64(40), v_j])
|
86 |
+
T.writes(var_decode_intermediate[v_i, v_j])
|
87 |
+
var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv26[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv27[v_i // T.int64(40), v_j]
|
88 |
+
for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):
|
89 |
+
with T.block("matmul"):
|
90 |
+
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
|
91 |
+
T.reads(lv3[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
|
92 |
+
T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
|
93 |
+
with T.init():
|
94 |
+
var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
|
95 |
+
var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv3[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
|
96 |
+
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
|
97 |
+
with T.block("T_add"):
|
98 |
+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
|
99 |
+
T.reads(lv1581[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
|
100 |
+
T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
|
101 |
+
p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv1581[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
|
102 |
+
|
103 |
+
@T.prim_func
|
104 |
+
def fused_decode1_matmul4(lv8: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv9: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv1583: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
|
105 |
+
T.func_attr({"tir.noalias": T.bool(True)})
|
106 |
+
# with T.block("root"):
|
107 |
+
var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
|
108 |
+
for i, j in T.grid(T.int64(4096), T.int64(4096)):
|
109 |
+
with T.block("decode"):
|
110 |
+
v_i, v_j = T.axis.remap("SS", [i, j])
|
111 |
+
T.reads(lv8[v_i // T.int64(5), v_j], lv9[v_i // T.int64(40), v_j])
|
112 |
+
T.writes(var_decode_intermediate[v_i, v_j])
|
113 |
+
var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv8[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv9[v_i // T.int64(40), v_j]
|
114 |
+
for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):
|
115 |
+
with T.block("matmul"):
|
116 |
+
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
|
117 |
+
T.reads(lv1583[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
|
118 |
+
T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
|
119 |
+
with T.init():
|
120 |
+
var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
|
121 |
+
var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1583[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
|
122 |
+
|
123 |
+
@T.prim_func
|
124 |
+
def fused_decode2_fused_matmul6_multiply(lv38: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv39: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")):
|
125 |
+
T.func_attr({"tir.noalias": T.bool(True)})
|
126 |
+
# with T.block("root"):
|
127 |
+
var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
|
128 |
+
var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
|
129 |
+
for i, j in T.grid(T.int64(4096), T.int64(11008)):
|
130 |
+
with T.block("decode"):
|
131 |
+
v_i, v_j = T.axis.remap("SS", [i, j])
|
132 |
+
T.reads(lv38[v_i // T.int64(5), v_j], lv39[v_i // T.int64(40), v_j])
|
133 |
+
T.writes(var_decode_intermediate[v_i, v_j])
|
134 |
+
var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv38[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv39[v_i // T.int64(40), v_j]
|
135 |
+
for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):
|
136 |
+
with T.block("matmul"):
|
137 |
+
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
|
138 |
+
T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
|
139 |
+
T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
|
140 |
+
with T.init():
|
141 |
+
var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
|
142 |
+
var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
|
143 |
+
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
|
144 |
+
with T.block("T_multiply"):
|
145 |
+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
|
146 |
+
T.reads(lv4[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
|
147 |
+
T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
|
148 |
+
p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv4[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
|
149 |
+
|
150 |
+
@T.prim_func
|
151 |
+
def fused_decode2_fused_matmul6_silu(lv32: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv33: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")):
|
152 |
+
T.func_attr({"tir.noalias": T.bool(True)})
|
153 |
+
# with T.block("root"):
|
154 |
+
var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
|
155 |
+
var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
|
156 |
+
compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
|
157 |
+
for i, j in T.grid(T.int64(4096), T.int64(11008)):
|
158 |
+
with T.block("decode"):
|
159 |
+
v_i, v_j = T.axis.remap("SS", [i, j])
|
160 |
+
T.reads(lv32[v_i // T.int64(5), v_j], lv33[v_i // T.int64(40), v_j])
|
161 |
+
T.writes(var_decode_intermediate[v_i, v_j])
|
162 |
+
var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv32[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv33[v_i // T.int64(40), v_j]
|
163 |
+
for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):
|
164 |
+
with T.block("matmul"):
|
165 |
+
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
|
166 |
+
T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
|
167 |
+
T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
|
168 |
+
with T.init():
|
169 |
+
var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
|
170 |
+
var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
|
171 |
+
for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
|
172 |
+
with T.block("compute"):
|
173 |
+
v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
|
174 |
+
T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])
|
175 |
+
T.writes(compute[v_i0, v_i1, v_i2])
|
176 |
+
compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2])
|
177 |
+
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
|
178 |
+
with T.block("T_multiply"):
|
179 |
+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
|
180 |
+
T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
|
181 |
+
T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
|
182 |
+
p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]
|
183 |
+
|
184 |
+
    @T.prim_func
    def fused_decode3_fused_matmul7_add(lv44: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), lv45: T.Buffer((T.int64(276), T.int64(4096)), "float16"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")
        for i, j in T.grid(T.int64(11008), T.int64(4096)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv44[v_i // T.int64(5), v_j], lv45[v_i // T.int64(40), v_j])
                T.writes(var_decode_intermediate[v_i, v_j])
                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv44[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv45[v_i // T.int64(40), v_j]
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv6[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv6[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv4[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv4[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]

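    # Dequantize the 3-bit-packed (4096, 32000) weight and project the hidden
    # state to 32000 vocabulary logits, casting the result to float32.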
    @T.prim_func
    def fused_decode4_fused_matmul8_cast2(lv2931: T.Buffer((T.int64(824), T.int64(32000)), "uint16"), lv2932: T.Buffer((T.int64(103), T.int64(32000)), "float16"), lv1575: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16")
        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(32000)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv2931[v_i // T.int64(5), v_j], lv2932[v_i // T.int64(40), v_j])
                T.writes(var_decode_intermediate[v_i, v_j])
                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv2931[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv2932[v_i // T.int64(40), v_j]
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv1575[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1575[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
            with T.block("compute"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])
                T.writes(p_output0_intermediate[v_i0, v_i1, v_i2])
                p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2])

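    # Split the 4096-wide hidden state into 32 heads of 128, then drop the
    # leading batch axis: (1, 1, 4096) -> (1, 32, 128).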
    @T.prim_func
    def fused_reshape2_squeeze(lv1591: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_T_squeeze_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(128)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(lv1591[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])
                T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv1591[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(128)):
            with T.block("T_squeeze"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2])
                T.writes(var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2] = var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2]

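    # Undo the head split: transpose (1, 32, 1, 128) to (1, 1, 32, 128) and
    # flatten the heads back into a single (1, 1, 4096) vector.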
    @T.prim_func
    def fused_transpose5_reshape4(lv1616: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), var_T_reshape_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_T_transpose_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(lv1616[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv1616[v_ax0, v_ax2, v_ax1, v_ax3]
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)])
                T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2] = var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)]

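    # Flatten the (1, 1) int32 token id into a length-1 vector.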
    @T.prim_func
    def reshape(A: T.Buffer((T.int64(1), T.int64(1)), "int32"), T_reshape: T.Buffer((T.int64(1),), "int32")):
        T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(1)):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(T.int64(1), ax0)
                T.reads(A[T.int64(0), T.int64(0)])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = A[T.int64(0), T.int64(0)]

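    # Add a sequence axis: (1, 4096) -> (1, 1, 4096).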
    @T.prim_func
    def reshape1(A: T.Buffer((T.int64(1), T.int64(4096)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[T.int64(0), v_ax2 % T.int64(4096)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax2 % T.int64(4096)]

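    # Split the hidden state into heads: (1, 1, 4096) -> (1, 1, 32, 128).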
    @T.prim_func
    def reshape2(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]

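    # RMSNorm over the 4096-wide hidden axis, accumulated in float32:
    # out = B * A / sqrt(mean(A * A) + 1e-6), where 0.000244140625 is 1/4096.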
    @T.prim_func
    def rms_norm1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), B: T.Buffer((T.int64(4096),), "float16"), rms_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        Ared_temp = T.alloc_buffer((T.int64(1), T.int64(1)))
        for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
            with T.block("Ared_temp"):
                v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
                T.reads(A[v_bsz, v_i, v_k])
                T.writes(Ared_temp[v_bsz, v_i])
                with T.init():
                    Ared_temp[v_bsz, v_i] = T.float32(0)
                Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k])
        for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
            with T.block("rms_norm"):
                v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
                T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
                T.writes(rms_norm[v_bsz, v_i, v_k])
                rms_norm[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / T.sqrt(Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))

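    # Rotary position embedding at sequence offset n, rotate-half style:
    # out = B * x + C * rotate_half(x), with B and C the (2048, 128)
    # precomputed position tables (cos/sin in the usual construction).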
    @T.prim_func
    def rotary_embedding1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), B: T.Buffer((T.int64(2048), T.int64(128)), "float16"), C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), rotary: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), n: T.int64):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
            with T.block("rotary"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(B[n + v_i1 - T.int64(1), v_i3], A[v_i0, v_i1, v_i2, v_i3 - T.int64(64):v_i3 - T.int64(64) + T.int64(129)], C[n + v_i1 - T.int64(1), v_i3])
                T.writes(rotary[v_i0, v_i1, v_i2, v_i3])
                rotary[v_i0, v_i1, v_i2, v_i3] = B[n + v_i1 - T.int64(1), v_i3] * A[v_i0, v_i1, v_i2, v_i3] + C[n + v_i1 - T.int64(1), v_i3] * T.Select(T.int64(64) <= v_i3, A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)], A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1))

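    # Slice out sequence position 0; with a length-1 sequence this is a copy.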
    @T.prim_func
    def slice1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), slice: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
        T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i, j, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
            with T.block("slice"):
                v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k])
                T.reads(A[v_i, T.int64(0), v_k])
                T.writes(slice[v_i, v_j, v_k])
                slice[v_i, v_j, v_k] = A[v_i, T.int64(0), v_k]

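    # Numerically stable softmax over the 32000 logits: subtract the running
    # max, exponentiate, and normalize by the sum of exponentials.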
    @T.prim_func
    def softmax(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(1)))
        T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)))
        T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(1)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
            with T.block("T_softmax_maxelem"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(A[v_i0, v_i1, v_k])
                T.writes(T_softmax_maxelem[v_i0, v_i1])
                with T.init():
                    T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e+38)
                T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], A[v_i0, v_i1, v_k])
        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
            with T.block("T_softmax_exp"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(A[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1])
                T.writes(T_softmax_exp[v_i0, v_i1, v_i2])
                T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(A[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1])
        for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
            with T.block("T_softmax_expsum"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(T_softmax_exp[v_i0, v_i1, v_k])
                T.writes(T_softmax_expsum[v_i0, v_i1])
                with T.init():
                    T_softmax_expsum[v_i0, v_i1] = T.float32(0)
                T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k]
        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
            with T.block("T_softmax_norm"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1])
                T.writes(T_softmax_norm[v_i0, v_i1, v_i2])
                T.block_attr({"axis": 2})
                T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1]

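    # Drop the leading batch axis: (1, 1, 32, 128) -> (1, 32, 128).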
    @T.prim_func
    def squeeze(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), T_squeeze: T.Buffer((T.int64(1), T.int64(32), T.int64(128)), "float16")):
        T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(128)):
            with T.block("T_squeeze"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2])
                T.writes(T_squeeze[v_ax0, v_ax1, v_ax2])
                T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2]

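    # Embedding lookup fused with dequantization: gather row C[0] of the
    # 3-bit-packed table and decode it into a (1, 4096) float16 vector.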
    @T.prim_func
    def take_decode(A: T.Buffer((T.int64(32000), T.int64(824)), "uint16"), B: T.Buffer((T.int64(32000), T.int64(103)), "float16"), C: T.Buffer((T.int64(1),), "int32"), take_decode_1: T.Buffer((T.int64(1), T.int64(4096)), "float16")):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i, j in T.grid(T.int64(1), T.int64(4096)):
            with T.block("take_decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(A[C[v_i], v_j // T.int64(5)], C[v_i], B[C[v_i], v_j // T.int64(40)])
                T.writes(take_decode_1[v_i, v_j])
                take_decode_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[C[v_i], v_j // T.int64(5)]), T.Cast("uint32", v_j % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[C[v_i], v_j // T.int64(40)]

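    # Swap the sequence and head axes: (1, 1, 32, 128) -> (1, 32, 1, 128).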
    @T.prim_func
    def transpose3(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
# fmt: on
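
# The "decode" blocks above all use the same 3-bit weight packing. The sketch
# below is a minimal NumPy reference of that scheme, under the assumptions the
# TIR implies: each uint16 word holds 5 three-bit fields (low bits first, top
# bit unused), stored values are offset by 3, and one float16 scale covers
# every 40 consecutive rows. The names `dequantize_3bit`, `packed`, and
# `scales` are illustrative only; they are not part of the module above.
import numpy as np

def dequantize_3bit(packed: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """packed: (K // 5, N) uint16, scales: (K // 40, N) float16 -> (K, N) float16."""
    k = packed.shape[0] * 5
    rows = np.arange(k)
    words = packed[rows // 5, :].astype(np.uint32)      # word holding row i is i // 5
    shifts = (rows % 5).astype(np.uint32) * 3           # field i % 5 starts at bit 3 * (i % 5)
    fields = (words >> shifts[:, None]) & np.uint32(7)  # extract the 3-bit field, 0..7
    # Subtract the offset and apply the per-40-row scale, as in the TIR decode.
    return (fields.astype(np.float16) - np.float16(3)) * scales[rows // 40, :]

# Tiny shape check on random data: 8 packed words per column -> 40 rows.
_packed = np.random.randint(0, 1 << 15, size=(8, 4), dtype=np.uint16)
_scales = np.random.rand(1, 4).astype(np.float16)
assert dequantize_3bit(_packed, _scales).shape == (40, 4)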