jetro30087
committed
Commit 5c5d1d3
Parent(s): 1a8f52d
Upload 105 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- README.md +32 -3
- debug/mod_tir_dynamic.py +640 -0
- debug/mod_tir_static.py +364 -0
- mod_cache_before_build_android.pkl +3 -0
- params/mlc-chat-config.json +15 -0
- params/ndarray-cache.json +0 -0
- params/params_shard_0.bin +3 -0
- params/params_shard_1.bin +3 -0
- params/params_shard_10.bin +3 -0
- params/params_shard_100.bin +3 -0
- params/params_shard_101.bin +3 -0
- params/params_shard_102.bin +3 -0
- params/params_shard_103.bin +3 -0
- params/params_shard_104.bin +3 -0
- params/params_shard_105.bin +3 -0
- params/params_shard_106.bin +3 -0
- params/params_shard_107.bin +3 -0
- params/params_shard_108.bin +3 -0
- params/params_shard_109.bin +3 -0
- params/params_shard_11.bin +3 -0
- params/params_shard_110.bin +3 -0
- params/params_shard_111.bin +3 -0
- params/params_shard_112.bin +3 -0
- params/params_shard_113.bin +3 -0
- params/params_shard_114.bin +3 -0
- params/params_shard_115.bin +3 -0
- params/params_shard_116.bin +3 -0
- params/params_shard_117.bin +3 -0
- params/params_shard_118.bin +3 -0
- params/params_shard_119.bin +3 -0
- params/params_shard_12.bin +3 -0
- params/params_shard_120.bin +3 -0
- params/params_shard_121.bin +3 -0
- params/params_shard_122.bin +3 -0
- params/params_shard_123.bin +3 -0
- params/params_shard_124.bin +3 -0
- params/params_shard_125.bin +3 -0
- params/params_shard_126.bin +3 -0
- params/params_shard_127.bin +3 -0
- params/params_shard_128.bin +3 -0
- params/params_shard_129.bin +3 -0
- params/params_shard_13.bin +3 -0
- params/params_shard_14.bin +3 -0
- params/params_shard_15.bin +3 -0
- params/params_shard_16.bin +3 -0
- params/params_shard_17.bin +3 -0
- params/params_shard_18.bin +3 -0
- params/params_shard_19.bin +3 -0
- params/params_shard_2.bin +3 -0
- params/params_shard_20.bin +3 -0
README.md
CHANGED
@@ -1,3 +1,32 @@

Model Card for vicuna-Wizard-7B-Uncensored-q3f16_0

Model Description

Note: Unlike the PC version, the Android MLC-LLM distribution does not have an option to edit the prompt configuration. This may result in unexpected responses.

This Language Model (vicuna-Wizard-7B-Uncensored-q3f16_0) is based on Facebook's "Llama" 7B-parameter model, trained on the Wizard-Vicuna uncensored dataset under a non-commercial license. It was developed and formatted specifically for use within the MLC-LLM project; more details are available at the MLC-LLM project URL.

The model is designed for research and general text-generation purposes. Thanks to MLC-LLM's Vulkan compatibility, the model works on both Nvidia and AMD graphics cards.

Model Usage

The vicuna-Wizard-7B-Uncensored-q3f16_0 model can generate human-like text useful for a variety of purposes, including but not limited to research, chatbots, and writing aids. You can use the model through MLC-LLM chat by copying it to the mlc-chat/dist folder of a compiled MLC-Chat client.

Limitations and Bias

Although the model is capable of generating high-quality text, it is not perfect. Some potential limitations and biases:

Output quality: Although trained on a large dataset, the model may occasionally produce text that is nonsensical or does not align with the input prompt.

Biases in the data: The model has been trained on the Wizard-Vicuna uncensored dataset and may have inherited biases present in that data. Despite our best efforts to minimize this, it may reflect biases in terms of gender, race, age, or other aspects.

Safety and content: The uncensored nature of the training dataset means the model could produce text that some people find offensive, inappropriate, or politically biased. We recommend using this model with care, especially in environments with young users or others who might be affected by such content.

Incorrect information: The model generates text based on patterns learned during training and has no access to real-world knowledge or updates beyond its training cut-off. The information it provides should always be verified for accuracy.

Ethical Considerations and Safety

While using this model, consider the following:

Always verify the information provided by the model with reliable external sources before using it to make decisions or for factual reference.
Monitor the output of the model for any potentially inappropriate or harmful content, especially if it is being used in a public or sensitive setting.
Keep in mind the potential biases inherited from the training data and account for these when interpreting the output.

Disclaimer

This model is provided as-is, and the developers make no warranties regarding its performance, appropriateness, or accuracy. Use it at your own risk.

license: other

See the runtime [instructions](https://mlc.ai/mlc-llm/docs/tutorials/runtime/cpp.html) for details.
debug/mod_tir_dynamic.py
ADDED
@@ -0,0 +1,640 @@
from tvm.script import ir as I
from tvm.script import tir as T

# fmt: off
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def extend_te(var_A: T.handle, var_concat_te: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), T.int64(1), n, n), "float16")
        m = T.int64()
        concat_te = T.match_buffer(var_concat_te, (T.int64(1), T.int64(1), n, m), "float16")
        # with T.block("root"):
        for b, _, i, j in T.grid(T.int64(1), T.int64(1), n, m):
            with T.block("concat_te"):
                v_b, v__, v_i, v_j = T.axis.remap("SSSS", [b, _, i, j])
                T.reads(A[v_b, v__, v_i, v_j + n - m])
                T.writes(concat_te[v_b, v__, v_i, v_j])
                concat_te[v_b, v__, v_i, v_j] = T.if_then_else(v_j < m - n, T.float16(65504), A[v_b, v__, v_i, v_j + n - m])

    @T.prim_func
    def full(var_T_full: T.handle):
        T.func_attr({"op_pattern": 0, "tir.noalias": T.bool(True)})
        n = T.int64()
        T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(1), T.int64(1), n), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), n):
            with T.block("T_full"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3])
                T_full[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(65504)

    @T.prim_func
    def fused_NT_matmul1_divide1_maximum_minimum_cast(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        m = T.int64()
        lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), m, T.int64(128)), "float16")
        lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16")
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m))
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(128)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(lv28[v_i0, v_i1, v_i2, v_k], lv29[v_i0, v_i1, v_i3, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[v_i0, v_i1, v_i2, v_k] * lv29[v_i0, v_i1, v_i3, v_k]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615)
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_maximum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_minimum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3])
                T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
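The matmul/divide/maximum/minimum/cast chain above is the attention-score computation: Q times K-transposed, scaled by roughly 1/sqrt(128) (the head dimension), clamped to the fp16 range, masked via an element-wise minimum, then cast to float32. A minimal NumPy sketch of the same dataflow, assuming the hypothetical names `q`, `k`, and `mask` stand in for lv28, lv29, and lv5:

```python
import numpy as np

def attn_scores_sketch(q, k, mask, scale=np.float16(1.0 / np.sqrt(128.0))):
    # NT matmul: contract the last axis of q against the last axis of k
    s = q.astype(np.float16) @ k.astype(np.float16).swapaxes(-1, -2)
    s = s * scale                            # T_divide (scale is an assumption: ~1/sqrt(head_dim))
    s = np.maximum(s, np.float16(-65504))    # T_maximum: clamp to the fp16 minimum
    s = np.minimum(s, mask)                  # T_minimum: masked positions carry -65504 in the mask
    return s.astype(np.float32)              # final cast, as in the "compute" block
```

The mask broadcasts over the 32 heads exactly as lv5's second axis of size 1 does in the TIR version.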
    @T.prim_func
    def fused_NT_matmul4_divide2_maximum1_minimum1_cast3(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), p_lv1606: T.handle, p_lv1582: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv1606 = T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16")
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n))
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(lv1605[v_i0, v_i1, v_i2, v_k], lv1606[v_i0, v_i1, v_i3, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1605[v_i0, v_i1, v_i2, v_k] * lv1606[v_i0, v_i1, v_i3, v_k]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615)
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_maximum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_minimum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3])
                T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])

    @T.prim_func
    def fused_decode1_NT_matmul(lv8: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv9: T.Buffer((T.int64(103), T.int64(4096)), "float16"), p_lv6: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(4096)), "float16")
        var_NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
        var_T_transpose_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(4096)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv8[v_i // T.int64(5), v_j], lv9[v_i // T.int64(40), v_j])
                T.writes(decode[v_i, v_j])
                decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv8[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv9[v_i // T.int64(40), v_j]
        for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(decode[v_ax1, v_ax0])
                T.writes(var_T_transpose_intermediate[v_ax0, v_ax1])
                var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv6[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv6[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k]
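The `decode` block above is the q3 dequantization step: each uint16 word in lv8 packs five 3-bit weights (hence `v_i // 5` and the `(v_i % 5) * 3` shift), each field is masked with 7, shifted back to signed range by subtracting 3, and scaled by one fp16 scale per 40 rows from lv9. A hypothetical NumPy mirror of just that expression (`dequantize_q3` is not a name from the source):

```python
import numpy as np

def dequantize_q3(packed, scales):
    # packed: uint16, 5 three-bit weights per word along axis 0; scales: fp16, one per 40 rows
    rows = packed.shape[0] * 5
    i = np.arange(rows)
    words = packed[i // 5, :].astype(np.uint32)            # word holding row i's field
    shift = ((i % 5) * 3).astype(np.uint32)[:, None]       # bit offset of the field
    q = (words >> shift) & 0x7                             # extract the 3-bit value
    return (q.astype(np.float16) - np.float16(3)) * scales[i // 40, :]
```

With 4096 weight rows, 4096 / 5 rounds up to 824 packed rows and 4096 / 40 to 103 scale rows, matching the lv8 and lv9 buffer shapes.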
    @T.prim_func
    def fused_decode1_fused_NT_matmul_add(lv29: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv30: T.Buffer((T.int64(103), T.int64(4096)), "float16"), p_lv41: T.handle, p_lv2: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv41 = T.match_buffer(p_lv41, (T.int64(1), n, T.int64(4096)), "float16")
        lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096)), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
        var_T_transpose_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(4096)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv29[v_i // T.int64(5), v_j], lv30[v_i // T.int64(40), v_j])
                T.writes(decode[v_i, v_j])
                decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv29[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv30[v_i // T.int64(40), v_j]
        for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(decode[v_ax1, v_ax0])
                T.writes(var_T_transpose_intermediate[v_ax0, v_ax1])
                var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv41[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv41[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv2[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def fused_decode2_fused_NT_matmul2_multiply(lv43: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv44: T.Buffer((T.int64(103), T.int64(11008)), "float16"), p_lv45: T.handle, p_lv132: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16")
        lv132 = T.match_buffer(p_lv132, (T.int64(1), n, T.int64(11008)), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16")
        # with T.block("root"):
        decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
        var_T_transpose_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(11008)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv43[v_i // T.int64(5), v_j], lv44[v_i // T.int64(40), v_j])
                T.writes(decode[v_i, v_j])
                decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv43[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv44[v_i // T.int64(40), v_j]
        for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(decode[v_ax1, v_ax0])
                T.writes(var_T_transpose_intermediate[v_ax0, v_ax1])
                var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv45[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv132[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv132[v_ax0, v_ax1, v_ax2] * var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def fused_decode2_fused_NT_matmul2_silu(lv36: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv37: T.Buffer((T.int64(103), T.int64(11008)), "float16"), p_lv45: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16")
        # with T.block("root"):
        decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
        var_T_transpose_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        compute = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(11008)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv36[v_i // T.int64(5), v_j], lv37[v_i // T.int64(40), v_j])
                T.writes(decode[v_i, v_j])
                decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv36[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv37[v_i // T.int64(40), v_j]
        for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(decode[v_ax1, v_ax0])
                T.writes(var_T_transpose_intermediate[v_ax0, v_ax1])
                var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv45[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k]
        for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("compute"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute[v_i0, v_i1, v_i2])
                compute[v_i0, v_i1, v_i2] = T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]
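In fused_decode2_fused_NT_matmul2_silu, the "compute" block takes the sigmoid of the matmul output and the "T_multiply" block multiplies the output by that sigmoid, i.e. the SiLU (swish) activation used in Llama's gated MLP. A minimal NumPy sketch of that pair of blocks:

```python
import numpy as np

def silu(x):
    # x * sigmoid(x): "compute" produces sigmoid(x), "T_multiply" does x * sigmoid(x)
    return x * (1.0 / (1.0 + np.exp(-x)))
```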
    @T.prim_func
    def fused_decode3_fused_NT_matmul3_add(lv50: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), lv51: T.Buffer((T.int64(276), T.int64(4096)), "float16"), p_lv5: T.handle, p_lv3: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv5 = T.match_buffer(p_lv5, (T.int64(1), n, T.int64(11008)), "float16")
        lv3 = T.match_buffer(p_lv3, (T.int64(1), n, T.int64(4096)), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
        var_T_transpose_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
        for i, j in T.grid(T.int64(11008), T.int64(4096)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv50[v_i // T.int64(5), v_j], lv51[v_i // T.int64(40), v_j])
                T.writes(decode[v_i, v_j])
                decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv50[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv51[v_i // T.int64(40), v_j]
        for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(decode[v_ax1, v_ax0])
                T.writes(var_T_transpose_intermediate[v_ax0, v_ax1])
                var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv5[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv5[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv3[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int64):
        T.func_attr({"tir.noalias": T.bool(True)})
        var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), n, n), "float16")
        # with T.block("root"):
        var_make_diag_mask_te_intermediate = T.alloc_buffer((n, n), "float16")
        for i, j in T.grid(n, n):
            with T.block("make_diag_mask_te"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads()
                T.writes(var_make_diag_mask_te_intermediate[v_i, v_j])
                var_make_diag_mask_te_intermediate[v_i, v_j] = T.Select(v_i < v_j, T.float16(-65504), T.float16(65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), n, n):
            with T.block("T_broadcast_to"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_make_diag_mask_te_intermediate[v_ax2, v_ax3])
                T.writes(var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_make_diag_mask_te_intermediate[v_ax2, v_ax3]
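fused_min_max_triu_te_broadcast_to builds the causal attention mask: positions above the diagonal (j > i) get the fp16 minimum -65504, everything else gets 65504, so a later element-wise minimum against the scores blocks attention to future tokens. A hypothetical NumPy equivalent (`causal_mask` is an illustrative name, not from the source):

```python
import numpy as np

def causal_mask(n):
    # T.Select(v_i < v_j, -65504, 65504), then broadcast to (1, 1, n, n)
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    mask = np.where(i < j, np.float16(-65504), np.float16(65504))
    return np.broadcast_to(mask, (1, 1, n, n))
```

Taking min(scores, mask) then leaves past positions untouched (scores are far below 65504) while forcing future positions down to -65504, which vanishes after softmax.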
    @T.prim_func
    def fused_softmax1_cast1(p_lv36: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n, m = T.int64(), T.int64()
        lv36 = T.match_buffer(p_lv36, (T.int64(1), T.int64(32), n, m))
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16")
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n))
        T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
        T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n))
        var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_maxelem"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv36[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
                with T.init():
                    T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38)
                T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv36[v_i0, v_i1, v_i2, v_k])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_exp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(lv36[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])
                T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
                T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv36[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_expsum"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
                with T.init():
                    T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)
|
370 |
+
T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]
|
371 |
+
for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
|
372 |
+
with T.block("T_softmax_norm"):
|
373 |
+
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
|
374 |
+
T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])
|
375 |
+
T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
|
376 |
+
T.block_attr({"axis": 3})
|
377 |
+
var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]
|
378 |
+
for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
|
379 |
+
with T.block("compute"):
|
380 |
+
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
|
381 |
+
T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
|
382 |
+
T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
|
383 |
+
var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
|
384 |
+
|
385 |
+
@T.prim_func
|
386 |
+
def fused_softmax2_cast4(p_lv1613: T.handle, p_output0: T.handle):
|
387 |
+
T.func_attr({"tir.noalias": T.bool(True)})
|
388 |
+
n = T.int64()
|
389 |
+
lv1613 = T.match_buffer(p_lv1613, (T.int64(1), T.int64(32), T.int64(1), n))
|
390 |
+
var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
|
391 |
+
# with T.block("root"):
|
392 |
+
T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
|
393 |
+
T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
|
394 |
+
T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
|
395 |
+
var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
|
396 |
+
for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
|
397 |
+
with T.block("T_softmax_maxelem"):
|
398 |
+
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
|
399 |
+
T.reads(lv1613[v_i0, v_i1, v_i2, v_k])
|
400 |
+
T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
|
401 |
+
with T.init():
|
402 |
+
T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38)
|
403 |
+
T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv1613[v_i0, v_i1, v_i2, v_k])
|
404 |
+
for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
|
405 |
+
with T.block("T_softmax_exp"):
|
406 |
+
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
|
407 |
+
T.reads(lv1613[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])
|
408 |
+
T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
|
409 |
+
T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv1613[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])
|
410 |
+
for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
|
411 |
+
with T.block("T_softmax_expsum"):
|
412 |
+
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
|
413 |
+
T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])
|
414 |
+
T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
|
415 |
+
with T.init():
|
416 |
+
T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)
|
417 |
+
T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]
|
418 |
+
for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
|
419 |
+
with T.block("T_softmax_norm"):
|
420 |
+
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
|
421 |
+
T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])
|
422 |
+
T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
|
423 |
+
T.block_attr({"axis": 3})
|
424 |
+
var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]
|
425 |
+
for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
|
426 |
+
with T.block("compute"):
|
427 |
+
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
|
428 |
+
T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
|
429 |
+
T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
|
430 |
+
var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
|
431 |
+
|
432 |
+
@T.prim_func
|
433 |
+
def matmul3(var_A: T.handle, var_B: T.handle, var_matmul: T.handle):
|
434 |
+
T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
|
435 |
+
n, m = T.int64(), T.int64()
|
436 |
+
A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), "float16")
|
437 |
+
B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(128)), "float16")
|
438 |
+
matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
|
439 |
+
# with T.block("root"):
|
440 |
+
for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), m):
|
441 |
+
with T.block("matmul"):
|
442 |
+
v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
|
443 |
+
T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
|
444 |
+
T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
|
445 |
+
with T.init():
|
446 |
+
matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
|
447 |
+
matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]
|
448 |
+
|
449 |
+
@T.prim_func
|
450 |
+
def matmul8(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):
|
451 |
+
T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
|
452 |
+
n = T.int64()
|
453 |
+
A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
|
454 |
+
B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
|
455 |
+
# with T.block("root"):
|
456 |
+
for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n):
|
457 |
+
with T.block("matmul"):
|
458 |
+
v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
|
459 |
+
T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
|
460 |
+
T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
|
461 |
+
with T.init():
|
462 |
+
matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
|
463 |
+
matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]
|
464 |
+
|
465 |
+
@T.prim_func
|
466 |
+
def reshape(var_A: T.handle, var_T_reshape: T.handle):
|
467 |
+
T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
|
468 |
+
n = T.int64()
|
469 |
+
A = T.match_buffer(var_A, (T.int64(1), n), "int32")
|
470 |
+
T_reshape = T.match_buffer(var_T_reshape, (n,), "int32")
|
471 |
+
# with T.block("root"):
|
472 |
+
for ax0 in range(n):
|
473 |
+
with T.block("T_reshape"):
|
474 |
+
v_ax0 = T.axis.spatial(n, ax0)
|
475 |
+
T.reads(A[T.int64(0), v_ax0 % n])
|
476 |
+
T.writes(T_reshape[v_ax0])
|
477 |
+
T_reshape[v_ax0] = A[T.int64(0), v_ax0 % n]
|
478 |
+
|
479 |
+
@T.prim_func
|
480 |
+
def reshape1(var_A: T.handle, var_T_reshape: T.handle):
|
481 |
+
T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
|
482 |
+
n = T.int64()
|
483 |
+
A = T.match_buffer(var_A, (n, T.int64(4096)), "float16")
|
484 |
+
T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16")
|
485 |
+
# with T.block("root"):
|
486 |
+
for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
|
487 |
+
with T.block("T_reshape"):
|
488 |
+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
|
489 |
+
T.reads(A[(v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096)])
|
490 |
+
T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
|
491 |
+
T_reshape[v_ax0, v_ax1, v_ax2] = A[(v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096)]
|
492 |
+
|
493 |
+
@T.prim_func
|
494 |
+
def reshape2(var_A: T.handle, var_T_reshape: T.handle):
|
495 |
+
T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
|
496 |
+
n = T.int64()
|
497 |
+
A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
|
498 |
+
T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
|
499 |
+
# with T.block("root"):
|
500 |
+
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
|
501 |
+
with T.block("T_reshape"):
|
502 |
+
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
|
503 |
+
T.reads(A[T.int64(0), ((v_ax2 * T.int64(128) + v_ax3) // T.int64(4096) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])
|
504 |
+
T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
|
505 |
+
T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), ((v_ax2 * T.int64(128) + v_ax3) // T.int64(4096) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]
|
506 |
+
|
507 |
+
@T.prim_func
|
508 |
+
def reshape3(var_A: T.handle, var_T_reshape: T.handle):
|
509 |
+
T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
|
510 |
+
m = T.int64()
|
511 |
+
A = T.match_buffer(var_A, (m, T.int64(32), T.int64(128)), "float16")
|
512 |
+
T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), m, T.int64(32), T.int64(128)), "float16")
|
513 |
+
# with T.block("root"):
|
514 |
+
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), m, T.int64(32), T.int64(128)):
|
515 |
+
with T.block("T_reshape"):
|
516 |
+
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
|
517 |
+
T.reads(A[((v_ax3 // T.int64(128) + v_ax2) // T.int64(32) + v_ax0 * m + v_ax1) % m, (v_ax3 // T.int64(128) + v_ax2) % T.int64(32), v_ax3 % T.int64(128)])
|
518 |
+
T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
|
519 |
+
T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[((v_ax3 // T.int64(128) + v_ax2) // T.int64(32) + v_ax0 * m + v_ax1) % m, (v_ax3 // T.int64(128) + v_ax2) % T.int64(32), v_ax3 % T.int64(128)]
|
520 |
+
|
521 |
+
@T.prim_func
|
522 |
+
def reshape4(var_A: T.handle, var_T_reshape: T.handle):
|
523 |
+
T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
|
524 |
+
n = T.int64()
|
525 |
+
A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
|
526 |
+
T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16")
|
527 |
+
# with T.block("root"):
|
528 |
+
for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
|
529 |
+
with T.block("T_reshape"):
|
530 |
+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
|
531 |
+
T.reads(A[T.int64(0), (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)])
|
532 |
+
T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
|
533 |
+
T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)]
|
534 |
+
|
535 |
+
@T.prim_func
|
536 |
+
def rms_norm(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle):
|
537 |
+
T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
|
538 |
+
n = T.int64()
|
539 |
+
A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
|
540 |
+
rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16")
|
541 |
+
# with T.block("root"):
|
542 |
+
Ared_temp = T.alloc_buffer((T.int64(1), n))
|
543 |
+
for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
|
544 |
+
with T.block("Ared_temp"):
|
545 |
+
v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
|
546 |
+
T.reads(A[v_bsz, v_i, v_k])
|
547 |
+
T.writes(Ared_temp[v_bsz, v_i])
|
548 |
+
with T.init():
|
549 |
+
Ared_temp[v_bsz, v_i] = T.float32(0)
|
550 |
+
Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k])
|
551 |
+
for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
|
552 |
+
with T.block("rms_norm"):
|
553 |
+
v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
|
554 |
+
T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
|
555 |
+
T.writes(rms_norm_1[v_bsz, v_i, v_k])
|
556 |
+
rms_norm_1[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / T.sqrt(Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))
|
557 |
+
|
558 |
+
@T.prim_func
|
559 |
+
def rotary_embedding(var_A: T.handle, B: T.Buffer((T.int64(2048), T.int64(128)), "float16"), C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), var_rotary: T.handle, m: T.int64):
|
560 |
+
T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
|
561 |
+
n = T.int64()
|
562 |
+
A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
|
563 |
+
rotary = T.match_buffer(var_rotary, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
|
564 |
+
# with T.block("root"):
|
565 |
+
for i0, i1, i2, i3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
|
566 |
+
with T.block("rotary"):
|
567 |
+
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
|
568 |
+
T.reads(B[m + v_i1 - n, v_i3], A[v_i0, v_i1, v_i2, v_i3 - T.int64(64):v_i3 - T.int64(64) + T.int64(129)], C[m + v_i1 - n, v_i3])
|
569 |
+
T.writes(rotary[v_i0, v_i1, v_i2, v_i3])
|
570 |
+
rotary[v_i0, v_i1, v_i2, v_i3] = B[m + v_i1 - n, v_i3] * A[v_i0, v_i1, v_i2, v_i3] + C[m + v_i1 - n, v_i3] * T.Select(T.int64(64) <= v_i3, A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)], A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1))
|
571 |
+
|
572 |
+
@T.prim_func
|
573 |
+
def slice(var_A: T.handle, slice_1: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
|
574 |
+
T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
|
575 |
+
n = T.int64()
|
576 |
+
A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
|
577 |
+
# with T.block("root"):
|
578 |
+
for i, j, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
|
579 |
+
with T.block("slice"):
|
580 |
+
v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k])
|
581 |
+
T.reads(A[v_i, n - T.int64(1), v_k])
|
582 |
+
T.writes(slice_1[v_i, v_j, v_k])
|
583 |
+
slice_1[v_i, v_j, v_k] = A[v_i, n - T.int64(1), v_k]
|
584 |
+
|
585 |
+
@T.prim_func
|
586 |
+
def squeeze(var_A: T.handle, var_T_squeeze: T.handle):
|
587 |
+
T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
|
588 |
+
n = T.int64()
|
589 |
+
A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
|
590 |
+
T_squeeze = T.match_buffer(var_T_squeeze, (n, T.int64(32), T.int64(128)), "float16")
|
591 |
+
# with T.block("root"):
|
592 |
+
for ax0, ax1, ax2 in T.grid(n, T.int64(32), T.int64(128)):
|
593 |
+
with T.block("T_squeeze"):
|
594 |
+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
|
595 |
+
T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2])
|
596 |
+
T.writes(T_squeeze[v_ax0, v_ax1, v_ax2])
|
597 |
+
T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2]
|
598 |
+
|
599 |
+
@T.prim_func
|
600 |
+
def take_decode(A: T.Buffer((T.int64(32000), T.int64(824)), "uint16"), B: T.Buffer((T.int64(32000), T.int64(103)), "float16"), var_C: T.handle, var_take_decode: T.handle):
|
601 |
+
T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
|
602 |
+
n = T.int64()
|
603 |
+
C = T.match_buffer(var_C, (n,), "int32")
|
604 |
+
take_decode_1 = T.match_buffer(var_take_decode, (n, T.int64(4096)), "float16")
|
605 |
+
# with T.block("root"):
|
606 |
+
for i, j in T.grid(n, T.int64(4096)):
|
607 |
+
with T.block("take_decode"):
|
608 |
+
v_i, v_j = T.axis.remap("SS", [i, j])
|
609 |
+
T.reads(A[C[v_i], v_j // T.int64(5)], C[v_i], B[C[v_i], v_j // T.int64(40)])
|
610 |
+
T.writes(take_decode_1[v_i, v_j])
|
611 |
+
take_decode_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[C[v_i], v_j // T.int64(5)]), T.Cast("uint32", v_j % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[C[v_i], v_j // T.int64(40)]
|
612 |
+
|
613 |
+
@T.prim_func
|
614 |
+
def transpose3(var_A: T.handle, var_T_transpose: T.handle):
|
615 |
+
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
|
616 |
+
n = T.int64()
|
617 |
+
A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
|
618 |
+
T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
|
619 |
+
# with T.block("root"):
|
620 |
+
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, T.int64(128)):
|
621 |
+
with T.block("T_transpose"):
|
622 |
+
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
|
623 |
+
T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
|
624 |
+
T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
|
625 |
+
T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
|
626 |
+
|
627 |
+
@T.prim_func
|
628 |
+
def transpose4(var_A: T.handle, var_T_transpose: T.handle):
|
629 |
+
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
|
630 |
+
n = T.int64()
|
631 |
+
A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
|
632 |
+
T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
|
633 |
+
# with T.block("root"):
|
634 |
+
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
|
635 |
+
with T.block("T_transpose"):
|
636 |
+
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
|
637 |
+
T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
|
638 |
+
T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
|
639 |
+
T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
|
640 |
+
# fmt: on
|
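The `take_decode` kernel above (and the `fused_decode*` kernels in these dumps) all reverse the same 3-bit weight packing: five codes per `uint16` word, one `float16` scale shared by every group of 40 weights, and a fixed zero point of 3. A minimal NumPy sketch of that unpacking, kept separate from the dumped module; the function name and shapes here are illustrative, not part of the model files:

```python
import numpy as np

def decode_3bit(packed: np.ndarray, scales: np.ndarray, rows: int) -> np.ndarray:
    """Unpack 3-bit quantized weights as the TIR decode blocks do.

    packed: uint16 array of shape (rows // 5, cols), five 3-bit codes per word
    scales: float16 array of shape (rows // 40, cols), one scale per 40 rows
    """
    i = np.arange(rows)
    words = packed[i // 5, :].astype(np.uint32)          # word holding row i's code
    shift = ((i % 5) * 3).astype(np.uint32)[:, None]     # bit offset within the word
    codes = (words >> shift) & np.uint32(7)              # extract the 3-bit code
    # Subtract the fixed zero point (3) and apply the per-group scale.
    return (codes.astype(np.float16) - np.float16(3)) * scales[i // 40, :]
```

This matches the TIR expression `(cast_fp16((word >> (i % 5) * 3) & 7) - 3) * scale[i // 40]` term by term; the real kernels fuse this unpacking directly into the following matmul instead of materializing the full weight matrix.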
debug/mod_tir_static.py
ADDED
@@ -0,0 +1,364 @@
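The `fused_softmax*_cast*` kernels in these dumps compute softmax with the standard max-subtraction trick (accumulating in `float32`, casting the result back to `float16`). A hedged NumPy sketch of the same four-stage computation, with shapes simplified to 2-D; this is an illustration, not part of the dumped module:

```python
import numpy as np

def stable_softmax_fp16(x: np.ndarray) -> np.ndarray:
    """Mirror the four TIR blocks: maxelem, exp, expsum, norm, then cast."""
    x = x.astype(np.float32)                   # kernels reduce in float32
    maxelem = x.max(axis=-1, keepdims=True)    # T_softmax_maxelem
    e = np.exp(x - maxelem)                    # T_softmax_exp (no overflow)
    expsum = e.sum(axis=-1, keepdims=True)     # T_softmax_expsum
    return (e / expsum).astype(np.float16)     # T_softmax_norm + final cast
```

Subtracting the row maximum before exponentiating keeps every `exp` argument at or below zero, which is why the kernels initialize `T_softmax_maxelem` to the most negative `float32` value rather than zero.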

from tvm.script import ir as I
from tvm.script import tir as T

# fmt: off
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def divide(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), B: T.Buffer((), "float32"), T_divide: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):
        T.func_attr({"op_pattern": 0, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[v_ax0, v_ax1, v_ax2], B[()])
                T.writes(T_divide[v_ax0, v_ax1, v_ax2])
                T_divide[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2] / B[()]

    @T.prim_func
    def fused_decode4_fused_matmul4_cast2(lv2931: T.Buffer((T.int64(824), T.int64(32000)), "uint16"), lv2932: T.Buffer((T.int64(103), T.int64(32000)), "float16"), lv3152: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16")
        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(32000)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv2931[v_i // T.int64(5), v_j], lv2932[v_i // T.int64(40), v_j])
                T.writes(var_decode_intermediate[v_i, v_j])
                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv2931[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv2932[v_i // T.int64(40), v_j]
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv3152[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv3152[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
            with T.block("compute"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])
                T.writes(p_output0_intermediate[v_i0, v_i1, v_i2])
                p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2])

    @T.prim_func
    def fused_decode5_fused_matmul7_add1(lv1605: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv1606: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv197: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1581: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(4096)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv1605[v_i // T.int64(5), v_j], lv1606[v_i // T.int64(40), v_j])
                T.writes(var_decode_intermediate[v_i, v_j])
                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1605[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1606[v_i // T.int64(40), v_j]
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv197[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv197[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv1581[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv1581[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def fused_decode5_matmul7(lv1587: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv1588: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv1583: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(4096)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv1587[v_i // T.int64(5), v_j], lv1588[v_i // T.int64(40), v_j])
                T.writes(var_decode_intermediate[v_i, v_j])
                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1587[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1588[v_i // T.int64(40), v_j]
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv1583[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1583[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]

    @T.prim_func
    def fused_decode6_fused_matmul9_multiply1(lv1617: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv1618: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(11008)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv1617[v_i // T.int64(5), v_j], lv1618[v_i // T.int64(40), v_j])
                T.writes(var_decode_intermediate[v_i, v_j])
                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1617[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1618[v_i // T.int64(40), v_j]
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv3[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def fused_decode6_fused_matmul9_silu1(lv1611: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv1612: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
        compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(11008)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv1611[v_i // T.int64(5), v_j], lv1612[v_i // T.int64(40), v_j])
                T.writes(var_decode_intermediate[v_i, v_j])
                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1611[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1612[v_i // T.int64(40), v_j]
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
            with T.block("compute"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute[v_i0, v_i1, v_i2])
                compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2])
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def fused_decode7_fused_matmul10_add1(lv1623: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), lv1624: T.Buffer((T.int64(276), T.int64(4096)), "float16"), lv200: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv198: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")
        for i, j in T.grid(T.int64(11008), T.int64(4096)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
|
166 |
+
T.reads(lv1623[v_i // T.int64(5), v_j], lv1624[v_i // T.int64(40), v_j])
|
167 |
+
T.writes(var_decode_intermediate[v_i, v_j])
|
168 |
+
var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1623[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1624[v_i // T.int64(40), v_j]
|
169 |
+
for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)):
|
170 |
+
with T.block("matmul"):
|
171 |
+
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
|
172 |
+
T.reads(lv200[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
|
173 |
+
T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
|
174 |
+
with T.init():
|
175 |
+
var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
|
176 |
+
var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv200[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
|
177 |
+
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
|
178 |
+
with T.block("T_add"):
|
179 |
+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
|
180 |
+
T.reads(lv198[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
|
181 |
+
T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
|
182 |
+
p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv198[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
|
183 |
+
|
184 |
+
@T.prim_func
|
185 |
+
def fused_reshape7_squeeze1(lv1591: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_T_squeeze_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(128)), "float16")):
|
186 |
+
T.func_attr({"tir.noalias": T.bool(True)})
|
187 |
+
# with T.block("root"):
|
188 |
+
var_T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")
|
189 |
+
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
|
190 |
+
with T.block("T_reshape"):
|
191 |
+
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
|
192 |
+
T.reads(lv1591[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])
|
193 |
+
T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
|
194 |
+
var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv1591[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]
|
195 |
+
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(128)):
|
196 |
+
with T.block("T_squeeze"):
|
197 |
+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
|
198 |
+
T.reads(var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2])
|
199 |
+
T.writes(var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2])
|
200 |
+
var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2] = var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2]
|
201 |
+
|
202 |
+
@T.prim_func
|
203 |
+
def fused_transpose7_reshape8(lv1616: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), var_T_reshape_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
|
204 |
+
T.func_attr({"tir.noalias": T.bool(True)})
|
205 |
+
# with T.block("root"):
|
206 |
+
var_T_transpose_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")
|
207 |
+
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
|
208 |
+
with T.block("T_transpose"):
|
209 |
+
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
|
210 |
+
T.reads(lv1616[v_ax0, v_ax2, v_ax1, v_ax3])
|
211 |
+
T.writes(var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
|
212 |
+
var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv1616[v_ax0, v_ax2, v_ax1, v_ax3]
|
213 |
+
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
|
214 |
+
with T.block("T_reshape"):
|
215 |
+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
|
216 |
+
T.reads(var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)])
|
217 |
+
T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2])
|
218 |
+
var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2] = var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)]
|
219 |
+
|
220 |
+
@T.prim_func
|
221 |
+
def reshape5(A: T.Buffer((T.int64(1), T.int64(1)), "int32"), T_reshape: T.Buffer((T.int64(1),), "int32")):
|
222 |
+
T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
|
223 |
+
# with T.block("root"):
|
224 |
+
for ax0 in range(T.int64(1)):
|
225 |
+
with T.block("T_reshape"):
|
226 |
+
v_ax0 = T.axis.spatial(T.int64(1), ax0)
|
227 |
+
T.reads(A[T.int64(0), T.int64(0)])
|
228 |
+
T.writes(T_reshape[v_ax0])
|
229 |
+
T_reshape[v_ax0] = A[T.int64(0), T.int64(0)]
|
230 |
+
|
231 |
+
@T.prim_func
|
232 |
+
def reshape6(A: T.Buffer((T.int64(1), T.int64(4096)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
|
233 |
+
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
|
234 |
+
# with T.block("root"):
|
235 |
+
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
|
236 |
+
with T.block("T_reshape"):
|
237 |
+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
|
238 |
+
T.reads(A[T.int64(0), v_ax2 % T.int64(4096)])
|
239 |
+
T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
|
240 |
+
T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax2 % T.int64(4096)]
|
241 |
+
|
242 |
+
@T.prim_func
|
243 |
+
def reshape7(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")):
|
244 |
+
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
|
245 |
+
# with T.block("root"):
|
246 |
+
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
|
247 |
+
with T.block("T_reshape"):
|
248 |
+
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
|
249 |
+
T.reads(A[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])
|
250 |
+
T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
|
251 |
+
T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]
|
252 |
+
|
253 |
+
@T.prim_func
|
254 |
+
def rms_norm1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), B: T.Buffer((T.int64(4096),), "float16"), rms_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
|
255 |
+
T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
|
256 |
+
# with T.block("root"):
|
257 |
+
Ared_temp = T.alloc_buffer((T.int64(1), T.int64(1)))
|
258 |
+
for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
|
259 |
+
with T.block("Ared_temp"):
|
260 |
+
v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
|
261 |
+
T.reads(A[v_bsz, v_i, v_k])
|
262 |
+
T.writes(Ared_temp[v_bsz, v_i])
|
263 |
+
with T.init():
|
264 |
+
Ared_temp[v_bsz, v_i] = T.float32(0)
|
265 |
+
Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k])
|
266 |
+
for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
|
267 |
+
with T.block("rms_norm"):
|
268 |
+
v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
|
269 |
+
T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
|
270 |
+
T.writes(rms_norm[v_bsz, v_i, v_k])
|
271 |
+
rms_norm[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / T.sqrt(Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))
|
272 |
+
|
273 |
+
@T.prim_func
|
274 |
+
def rotary_embedding1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), B: T.Buffer((T.int64(2048), T.int64(128)), "float16"), C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), rotary: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), n: T.int64):
|
275 |
+
T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
|
276 |
+
# with T.block("root"):
|
277 |
+
for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
|
278 |
+
with T.block("rotary"):
|
279 |
+
v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
|
280 |
+
T.reads(B[n + v_i1 - T.int64(1), v_i3], A[v_i0, v_i1, v_i2, v_i3 - T.int64(64):v_i3 - T.int64(64) + T.int64(129)], C[n + v_i1 - T.int64(1), v_i3])
|
281 |
+
T.writes(rotary[v_i0, v_i1, v_i2, v_i3])
|
282 |
+
rotary[v_i0, v_i1, v_i2, v_i3] = B[n + v_i1 - T.int64(1), v_i3] * A[v_i0, v_i1, v_i2, v_i3] + C[n + v_i1 - T.int64(1), v_i3] * T.Select(T.int64(64) <= v_i3, A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)], A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1))
|
283 |
+
|
284 |
+
@T.prim_func
|
285 |
+
def slice1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), slice: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
|
286 |
+
T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
|
287 |
+
# with T.block("root"):
|
288 |
+
for i, j, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
|
289 |
+
with T.block("slice"):
|
290 |
+
v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k])
|
291 |
+
T.reads(A[v_i, T.int64(0), v_k])
|
292 |
+
T.writes(slice[v_i, v_j, v_k])
|
293 |
+
slice[v_i, v_j, v_k] = A[v_i, T.int64(0), v_k]
|
294 |
+
|
295 |
+
@T.prim_func
|
296 |
+
def softmax(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):
|
297 |
+
T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
|
298 |
+
# with T.block("root"):
|
299 |
+
T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(1)))
|
300 |
+
T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)))
|
301 |
+
T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(1)))
|
302 |
+
for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
|
303 |
+
with T.block("T_softmax_maxelem"):
|
304 |
+
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
|
305 |
+
T.reads(A[v_i0, v_i1, v_k])
|
306 |
+
T.writes(T_softmax_maxelem[v_i0, v_i1])
|
307 |
+
with T.init():
|
308 |
+
T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e+38)
|
309 |
+
T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], A[v_i0, v_i1, v_k])
|
310 |
+
for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
|
311 |
+
with T.block("T_softmax_exp"):
|
312 |
+
v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
|
313 |
+
T.reads(A[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1])
|
314 |
+
T.writes(T_softmax_exp[v_i0, v_i1, v_i2])
|
315 |
+
T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(A[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1])
|
316 |
+
for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
|
317 |
+
with T.block("T_softmax_expsum"):
|
318 |
+
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
|
319 |
+
T.reads(T_softmax_exp[v_i0, v_i1, v_k])
|
320 |
+
T.writes(T_softmax_expsum[v_i0, v_i1])
|
321 |
+
with T.init():
|
322 |
+
T_softmax_expsum[v_i0, v_i1] = T.float32(0)
|
323 |
+
T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k]
|
324 |
+
for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
|
325 |
+
with T.block("T_softmax_norm"):
|
326 |
+
v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
|
327 |
+
T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1])
|
328 |
+
T.writes(T_softmax_norm[v_i0, v_i1, v_i2])
|
329 |
+
T.block_attr({"axis": 2})
|
330 |
+
T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1]
|
331 |
+
|
332 |
+
@T.prim_func
|
333 |
+
def squeeze1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), T_squeeze: T.Buffer((T.int64(1), T.int64(32), T.int64(128)), "float16")):
|
334 |
+
T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
|
335 |
+
# with T.block("root"):
|
336 |
+
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(128)):
|
337 |
+
with T.block("T_squeeze"):
|
338 |
+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
|
339 |
+
T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2])
|
340 |
+
T.writes(T_squeeze[v_ax0, v_ax1, v_ax2])
|
341 |
+
T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2]
|
342 |
+
|
343 |
+
@T.prim_func
|
344 |
+
def take_decode1(A: T.Buffer((T.int64(32000), T.int64(824)), "uint16"), B: T.Buffer((T.int64(32000), T.int64(103)), "float16"), C: T.Buffer((T.int64(1),), "int32"), take_decode: T.Buffer((T.int64(1), T.int64(4096)), "float16")):
|
345 |
+
T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
|
346 |
+
# with T.block("root"):
|
347 |
+
for i, j in T.grid(T.int64(1), T.int64(4096)):
|
348 |
+
with T.block("take_decode"):
|
349 |
+
v_i, v_j = T.axis.remap("SS", [i, j])
|
350 |
+
T.reads(A[C[v_i], v_j // T.int64(5)], C[v_i], B[C[v_i], v_j // T.int64(40)])
|
351 |
+
T.writes(take_decode[v_i, v_j])
|
352 |
+
take_decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[C[v_i], v_j // T.int64(5)]), T.Cast("uint32", v_j % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[C[v_i], v_j // T.int64(40)]
|
353 |
+
|
354 |
+
@T.prim_func
|
355 |
+
def transpose6(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):
|
356 |
+
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
|
357 |
+
# with T.block("root"):
|
358 |
+
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128)):
|
359 |
+
with T.block("T_transpose"):
|
360 |
+
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
|
361 |
+
T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
|
362 |
+
T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
|
363 |
+
T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
|
364 |
+
# fmt: on
|
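The `decode` and `take_decode` blocks above all apply the same q3f16_0 dequantization: five 3-bit values are packed into each `uint16`, extracted with a shift-and-mask, re-centered around the zero-point 3, and multiplied by one `float16` scale per group of 40 weights. A minimal NumPy sketch of that arithmetic (a stand-in for the TIR loops; the function name and shapes are illustrative, not from this repo):

```python
import numpy as np

def decode_q3(packed: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Dequantize q3f16_0-style weights, mirroring the TIR "decode" blocks.

    packed: uint16 array of shape (n // 5, m), five 3-bit values per element.
    scale:  float16 array of shape (n // 40, m), one scale per 40-row group.
    Returns a float16 array of shape (n, m).
    """
    n = packed.shape[0] * 5
    rows = np.arange(n)
    # shift_right(uint32(packed[i // 5, j]), (i % 5) * 3) & 7, as in the TIR
    vals = (packed[rows // 5].astype(np.uint32) >> ((rows % 5) * 3)[:, None]) & 7
    # subtract the zero-point 3, then apply the per-group scale
    return (vals.astype(np.float16) - np.float16(3)) * scale[rows // 40]
```

Packing the quantized values 1..5 into one `uint16` (1 + 2<<3 + 3<<6 + 4<<9 + 5<<12 = 22737) and decoding with a scale of 2.0 recovers (q - 3) * 2 for each 3-bit field.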
mod_cache_before_build_android.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:61f76a70ba5f4b9b97295c8d8497a2ce557191ef78236c1173d29ab4731775ac
+size 33453240
params/mlc-chat-config.json
ADDED
@@ -0,0 +1,15 @@
+{
+    "model_lib": "vicuna-Wizard-7B-Uncensored-android-q3f16_0",
+    "local_id": "vicuna-Wizard-7B-Uncensored-android-q3f16_0",
+    "conv_template": "conv_one_shot",
+    "temperature": 0.7,
+    "repetition_penalty": 1.0,
+    "top_p": 0.95,
+    "mean_gen_len": 128,
+    "max_gen_len": 512,
+    "shift_fill_factor": 0.3,
+    "tokenizer_files": [
+        "tokenizer.json",
+        "tokenizer.model"
+    ]
+}
params/ndarray-cache.json
ADDED
The diff for this file is too large to render. See raw diff.
params/params_shard_0.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:05e82cb7f41683ebca9c5dac10c6c6cb3114f64c1cb92174fa1ab2eb91c8bb58
+size 52736000
params/params_shard_1.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f475ee094bbe8738202e5b90e34e6d738207ad9d60c80ccf10d779b507f27865
+size 30955008
params/params_shard_10.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:90fde15cc917dba6523e863f9ba68a06ffa9a9aa6314aa958639bb35f8bd8652
+size 18141184
params/params_shard_100.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45ffa798effa1f4dcc0e18fed876e630d5158048deab79b0ce001cb35201dd1f
+size 29578240
params/params_shard_101.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:304c07013a1d51ab9044b56ebc9796f4ac9b1710bc209c4cece32f68551e1624
+size 18141184
params/params_shard_102.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:46c31d3d4ab8e92ded2da394d808016276379223f728bee7cfc8f0e61a0ad5c5
+size 18141184
params/params_shard_103.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:69be2551dd7b84c72230f40566ff587fc0ec7211cca086ef8a0e231b7fca7d23
+size 32643584
params/params_shard_104.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89dae65b57366525d75ef3b4c3f9c0355e381501d41988590e9c26ba4af29932
+size 30210560
params/params_shard_105.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca0dcea3d364a2ecbf5d0342013a1b05682d68fc941c069ed87af1447ca5d801
+size 18141184
params/params_shard_106.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fee3ea256c0ae4d22230451a2dac4924975b715daad071f15734d4bf3bc88290
+size 18141184
params/params_shard_107.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0dfd94eea61cc08152dd9ffd799a2fbc01eb9b528e7109e80b4746119fe252eb
+size 18087936
params/params_shard_108.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:43caa8e093339b27fd7d269972a04f0e6a3528b40ce7dbcffa1845e0b0202ed4
+size 29578240
params/params_shard_109.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ea8b8e5a91f56582dcce624e269c5d93f9ca921949711967d3b884e50d3c13f
+size 18141184
params/params_shard_11.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9bc9fa98e51b0473c5eec445fb8816069e0f36455e7026cfb970a9268e70e80b
+size 18087936
params/params_shard_110.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:79ddd8fafc513a13bf2cccd87349a5af7b113c78c44ea8190a74b2c992cab689
+size 18141184
params/params_shard_111.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e731157580eaa87e96639a3f2b754259030a395360bf359125923844559ac6c6
+size 32643584
params/params_shard_112.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e37507121ee529dcf48c793649622a63ed54b09f06b2fa41ff07405f61de6e2
+size 30210560
params/params_shard_113.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b7220a590e7a69a483487510e9b40144965611a5f0900ba9485cc6f6efc07b29
+size 18141184
params/params_shard_114.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b8edc3a8867cb78f134e6a8658fc50204b5532a601bd7fcdcae4ad7883104b0
+size 18141184
params/params_shard_115.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b9c50d47d81180c3fc15f01cc7e2adcf3c37f8f6465eab91395d4fa7f0e8f33
+size 18087936
params/params_shard_116.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2f60e485c53239dc06cc00de12eed418233c4252229a9083add27b87c067147
+size 29578240
params/params_shard_117.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:212d6a7e47c4bf4fcff8833ca24e0d9511802adee6b9cae168b4ff0055282cd3
+size 18141184
params/params_shard_118.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:65aa7f26130e465fb9444786e399b04c84c220cb23f9f1cd855ca7a14464be88
+size 18141184
params/params_shard_119.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f79cdedfcfd480c63c21017cbe1097fb6e6c1d2f8df7b0600efe466a5430cc29
+size 32643584
params/params_shard_12.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:190f3293399e9520a5efe458f5a9288e03e2faf5adc538bf35d31d9b4673d5e1
+size 29578240
params/params_shard_120.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d01cbd2eb461c329dfc89c9826bca789d4fb5b2b4afcb48f66c81f41d7408abe
+size 30210560
params/params_shard_121.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:794717e19321f502f178bda5d8ea616fb48c19907c2f8747a8d50fd20b9d7f77
+size 18141184
params/params_shard_122.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6bf6c6bc03172f2a58beea56dce92aa2ed933358e7a97e98bfe06e8b747cff8f
+size 18141184
params/params_shard_123.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:76d3cf47f15deda8a3598fecab044153e11c4fe88a2c6bc58e65fa2f8cdddc7c
+size 18087936
params/params_shard_124.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8db9fbc83e5a9fe7d4e294540efaefb325ba5573bdc0fb20ea6c715eaaf3f7a0
+size 29578240
params/params_shard_125.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e8f2963ec1e598e6ee85a91c21dce132978347a4abc812b444b9f900ea3d45b6
+size 18141184
params/params_shard_126.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8fb390ac64f23f8782d141f6f72e11b5aa9b51bb211fb16f4b3a61bf3203784
+size 18141184
params/params_shard_127.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:012dd36b4436148eaae04869543827c5e49c29257fe383ecdd7e8bfa257d00cf
+size 32643584
params/params_shard_128.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d3304f8d32e19db6f4b19c07893ad06558cddf74f19de8c9d02da2418b23829
+size 52736000
params/params_shard_129.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97a4dd35920543a930cef9887f429cc7a7cc73aff0ddd010e9a9407e23efd825
+size 29208576
params/params_shard_13.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83d54fb1a8389d74cf9068c2c363372649ac0b46c9c15f9632b8c17bd9f77c00
+size 18141184
params/params_shard_14.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:60e75e756b9d83d6e5b522144a9f958d1c03ee858f8fad5e4c9e6312cd391763
+size 18141184
params/params_shard_15.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7eb0e11ce7c95f3ce42192fc27a83674e8e5da6f328353f8e7a269a952c92502
+size 32643584
params/params_shard_16.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8375ce114ca4e3dbf33f39cb697f2cd5b9fbf62a05a48a12e0ac40e1d782929
+size 30210560
params/params_shard_17.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:537be9cd347dac60a3cae2fde68c8b0fdc5c488508164aa86fc5f3fe7938f67e
+size 18141184
params/params_shard_18.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e8a3464fef817a62888684d12c56dfad343f79a68825c83278460c562a8f95c
+size 18141184
params/params_shard_19.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1d615fc53d422efa5d54033a79c34e0ca66b52bacc8b30d75163de2bd9821ec
+size 18087936
params/params_shard_2.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a5d64f7b0ff1fe6315b220755b175a2f4641e0e66591d7b25cf341d88b65afa
+size 18141184
params/params_shard_20.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f977dce5c6b01dda9e9c6b7524ce35cf5adc485cab682c7340b2fa06a231aa77
+size 29578240