jetro30087 committed on
Commit
5c5d1d3
1 Parent(s): 1a8f52d

Upload 105 files

Browse files
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +32 -3
  2. debug/mod_tir_dynamic.py +640 -0
  3. debug/mod_tir_static.py +364 -0
  4. mod_cache_before_build_android.pkl +3 -0
  5. params/mlc-chat-config.json +15 -0
  6. params/ndarray-cache.json +0 -0
  7. params/params_shard_0.bin +3 -0
  8. params/params_shard_1.bin +3 -0
  9. params/params_shard_10.bin +3 -0
  10. params/params_shard_100.bin +3 -0
  11. params/params_shard_101.bin +3 -0
  12. params/params_shard_102.bin +3 -0
  13. params/params_shard_103.bin +3 -0
  14. params/params_shard_104.bin +3 -0
  15. params/params_shard_105.bin +3 -0
  16. params/params_shard_106.bin +3 -0
  17. params/params_shard_107.bin +3 -0
  18. params/params_shard_108.bin +3 -0
  19. params/params_shard_109.bin +3 -0
  20. params/params_shard_11.bin +3 -0
  21. params/params_shard_110.bin +3 -0
  22. params/params_shard_111.bin +3 -0
  23. params/params_shard_112.bin +3 -0
  24. params/params_shard_113.bin +3 -0
  25. params/params_shard_114.bin +3 -0
  26. params/params_shard_115.bin +3 -0
  27. params/params_shard_116.bin +3 -0
  28. params/params_shard_117.bin +3 -0
  29. params/params_shard_118.bin +3 -0
  30. params/params_shard_119.bin +3 -0
  31. params/params_shard_12.bin +3 -0
  32. params/params_shard_120.bin +3 -0
  33. params/params_shard_121.bin +3 -0
  34. params/params_shard_122.bin +3 -0
  35. params/params_shard_123.bin +3 -0
  36. params/params_shard_124.bin +3 -0
  37. params/params_shard_125.bin +3 -0
  38. params/params_shard_126.bin +3 -0
  39. params/params_shard_127.bin +3 -0
  40. params/params_shard_128.bin +3 -0
  41. params/params_shard_129.bin +3 -0
  42. params/params_shard_13.bin +3 -0
  43. params/params_shard_14.bin +3 -0
  44. params/params_shard_15.bin +3 -0
  45. params/params_shard_16.bin +3 -0
  46. params/params_shard_17.bin +3 -0
  47. params/params_shard_18.bin +3 -0
  48. params/params_shard_19.bin +3 -0
  49. params/params_shard_2.bin +3 -0
  50. params/params_shard_20.bin +3 -0
README.md CHANGED
@@ -1,3 +1,32 @@
- ---
- license: other
- ---

Model Card for vicuna-Wizard-7B-Uncensored-q3f16_0

Model Description

Note: Unlike the PC version, the Android MLC-LLM distribution does not offer an option to edit the prompt configuration, which may result in unexpected responses.

This language model (vicuna-Wizard-7B-Uncensored-q3f16_0) is based on Facebook's "Llama" 7B-parameter model, trained on the Wizard-Vicuna uncensored dataset under a non-commercial license. It was developed and formatted specifically for use within the MLC-LLM project; see the MLC-LLM project URL for more details.

The model is designed for research and general text generation. Thanks to MLC-LLM's Vulkan compatibility, it runs on both Nvidia and AMD graphics cards.

Model Usage

The vicuna-Wizard-7B-Uncensored-q3f16_0 model can generate human-like text for a variety of purposes, including but not limited to research, chatbots, and writing aids. You can use the model through MLC-LLM chat by copying it to the mlc-chat/dist folder of a compiled MLC-Chat client.

Limitations and Bias

Although the model can generate high-quality text, it is not perfect. Potential limitations and biases include:

Output quality: Although trained on a large dataset, the model may occasionally produce text that is nonsensical or does not align with the input prompt.

Biases in the data: The model was trained on the Wizard-Vicuna uncensored dataset and may have inherited biases present in that data. Despite our best efforts to minimize this, it may reflect biases in terms of gender, race, age, or other aspects.

Safety and content: Because the training dataset is uncensored, the model may produce text that some people find offensive, inappropriate, or politically biased. We recommend using this model with care, especially in environments with young users or others who might be affected by such content.

Incorrect information: The model generates text based on patterns learned during training and has no access to real-world knowledge or updates beyond its training cut-off. Always verify the information it provides for accuracy.

Ethical Considerations and Safety

While using this model, consider the following:

Always verify the information provided by the model against reliable external sources before using it to make decisions or as factual reference.
Monitor the model's output for potentially inappropriate or harmful content, especially in public or sensitive settings.
Keep in mind the potential biases inherited from the training data and account for them when interpreting the output.

Disclaimer

This model is provided as-is, and the developers make no warranties regarding its performance, appropriateness, or accuracy. Use it at your own risk.

license: other

See the instructions at https://mlc.ai/mlc-llm/docs/tutorials/runtime/cpp.html for details.
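The q3f16_0 suffix refers to 3-bit weight quantization with float16 group scales, which is what the `decode` blocks in the TIR modules in this commit compute. The following NumPy sketch mirrors that decode arithmetic (five 3-bit codes packed per uint16 along the row axis, one scale per 40 rows); the function and variable names here are illustrative, not part of the repository:

```python
import numpy as np

def decode_q3(packed: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """Dequantize q3f16_0-style weights.

    packed: (rows // 5, cols) uint16 — five 3-bit codes per element along rows
    scales: (rows // 40, cols) float16 — one group scale per 40 decoded rows
    """
    rows = packed.shape[0] * 5
    cols = packed.shape[1]
    out = np.empty((rows, cols), dtype=np.float16)
    for i in range(rows):
        # Extract the (i % 5)-th 3-bit code, recenter around the zero point 3,
        # and apply the per-group float16 scale — as in the TIR `decode` blocks.
        code = (packed[i // 5].astype(np.uint32) >> np.uint32(3 * (i % 5))) & np.uint32(7)
        out[i] = (code.astype(np.float16) - np.float16(3)) * scales[i // 40]
    return out
```

This matches the expression in the `decode` blocks of the `fused_decode*` kernels below, where `lv8[v_i // 5, v_j]` holds the packed codes and `lv9[v_i // 40, v_j]` the scales.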
debug/mod_tir_dynamic.py ADDED
@@ -0,0 +1,640 @@
from tvm.script import ir as I
from tvm.script import tir as T

# fmt: off
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def extend_te(var_A: T.handle, var_concat_te: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), T.int64(1), n, n), "float16")
        m = T.int64()
        concat_te = T.match_buffer(var_concat_te, (T.int64(1), T.int64(1), n, m), "float16")
        # with T.block("root"):
        for b, _, i, j in T.grid(T.int64(1), T.int64(1), n, m):
            with T.block("concat_te"):
                v_b, v__, v_i, v_j = T.axis.remap("SSSS", [b, _, i, j])
                T.reads(A[v_b, v__, v_i, v_j + n - m])
                T.writes(concat_te[v_b, v__, v_i, v_j])
                concat_te[v_b, v__, v_i, v_j] = T.if_then_else(v_j < m - n, T.float16(65504), A[v_b, v__, v_i, v_j + n - m])

    @T.prim_func
    def full(var_T_full: T.handle):
        T.func_attr({"op_pattern": 0, "tir.noalias": T.bool(True)})
        n = T.int64()
        T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(1), T.int64(1), n), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), n):
            with T.block("T_full"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3])
                T_full[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(65504)

    @T.prim_func
    def fused_NT_matmul1_divide1_maximum_minimum_cast(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        m = T.int64()
        lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), m, T.int64(128)), "float16")
        lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16")
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m))
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(128)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(lv28[v_i0, v_i1, v_i2, v_k], lv29[v_i0, v_i1, v_i3, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[v_i0, v_i1, v_i2, v_k] * lv29[v_i0, v_i1, v_i3, v_k]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615)
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_maximum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_minimum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3])
                T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])

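The divide constant 0.088397790055248615 in the attention-score kernels above is approximately 1 / sqrt(128), the standard scaled-dot-product factor for head dimension 128; the kernel then clamps to the float16 range and applies the causal mask via a minimum with `lv5`. A minimal NumPy sketch of the same sequence of blocks (shapes and names here are illustrative, not the kernel's API):

```python
import numpy as np

def attention_scores(q, k, mask):
    # q: (heads, n, 128), k: (heads, m, 128), mask: (n, m) with 65504 = pass, -65504 = blocked
    d = q.shape[-1]
    s = (q @ k.transpose(0, 2, 1)) / np.sqrt(d)  # NT_matmul + T_divide
    s = np.maximum(s, np.float16(-65504))        # T_maximum: clamp to the fp16 minimum
    s = np.minimum(s, mask)                      # T_minimum: apply the causal mask
    return s.astype(np.float32)                  # compute: cast for the softmax that follows
```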
    @T.prim_func
    def fused_NT_matmul4_divide2_maximum1_minimum1_cast3(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), p_lv1606: T.handle, p_lv1582: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv1606 = T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16")
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n))
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(lv1605[v_i0, v_i1, v_i2, v_k], lv1606[v_i0, v_i1, v_i3, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1605[v_i0, v_i1, v_i2, v_k] * lv1606[v_i0, v_i1, v_i3, v_k]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615)
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_maximum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_minimum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3])
                T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])

    @T.prim_func
    def fused_decode1_NT_matmul(lv8: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv9: T.Buffer((T.int64(103), T.int64(4096)), "float16"), p_lv6: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(4096)), "float16")
        var_NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
        var_T_transpose_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(4096)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv8[v_i // T.int64(5), v_j], lv9[v_i // T.int64(40), v_j])
                T.writes(decode[v_i, v_j])
                decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv8[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv9[v_i // T.int64(40), v_j]
        for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(decode[v_ax1, v_ax0])
                T.writes(var_T_transpose_intermediate[v_ax0, v_ax1])
                var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv6[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv6[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k]

    @T.prim_func
    def fused_decode1_fused_NT_matmul_add(lv29: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv30: T.Buffer((T.int64(103), T.int64(4096)), "float16"), p_lv41: T.handle, p_lv2: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv41 = T.match_buffer(p_lv41, (T.int64(1), n, T.int64(4096)), "float16")
        lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096)), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
        var_T_transpose_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(4096)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv29[v_i // T.int64(5), v_j], lv30[v_i // T.int64(40), v_j])
                T.writes(decode[v_i, v_j])
                decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv29[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv30[v_i // T.int64(40), v_j]
        for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(decode[v_ax1, v_ax0])
                T.writes(var_T_transpose_intermediate[v_ax0, v_ax1])
                var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv41[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv41[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv2[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def fused_decode2_fused_NT_matmul2_multiply(lv43: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv44: T.Buffer((T.int64(103), T.int64(11008)), "float16"), p_lv45: T.handle, p_lv132: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16")
        lv132 = T.match_buffer(p_lv132, (T.int64(1), n, T.int64(11008)), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16")
        # with T.block("root"):
        decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
        var_T_transpose_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(11008)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv43[v_i // T.int64(5), v_j], lv44[v_i // T.int64(40), v_j])
                T.writes(decode[v_i, v_j])
                decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv43[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv44[v_i // T.int64(40), v_j]
        for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(decode[v_ax1, v_ax0])
                T.writes(var_T_transpose_intermediate[v_ax0, v_ax1])
                var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv45[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv132[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv132[v_ax0, v_ax1, v_ax2] * var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def fused_decode2_fused_NT_matmul2_silu(lv36: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv37: T.Buffer((T.int64(103), T.int64(11008)), "float16"), p_lv45: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16")
        # with T.block("root"):
        decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
        var_T_transpose_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        compute = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        for i, j in T.grid(T.int64(4096), T.int64(11008)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv36[v_i // T.int64(5), v_j], lv37[v_i // T.int64(40), v_j])
                T.writes(decode[v_i, v_j])
                decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv36[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv37[v_i // T.int64(40), v_j]
        for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(decode[v_ax1, v_ax0])
                T.writes(var_T_transpose_intermediate[v_ax0, v_ax1])
                var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv45[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k]
        for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("compute"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute[v_i0, v_i1, v_i2])
                compute[v_i0, v_i1, v_i2] = T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]

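The fused_decode2_fused_NT_matmul2_silu kernel above ends with a "compute" block (sigmoid) followed by a "T_multiply" block, i.e. the SiLU activation x * sigmoid(x) used in Llama's MLP gate. A minimal NumPy sketch of that activation (the function name is illustrative):

```python
import numpy as np

def silu(x: np.ndarray) -> np.ndarray:
    # SiLU (swish): x * sigmoid(x), matching the "compute" + "T_multiply" blocks
    return x * (1.0 / (1.0 + np.exp(-x)))
```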
    @T.prim_func
    def fused_decode3_fused_NT_matmul3_add(lv50: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), lv51: T.Buffer((T.int64(276), T.int64(4096)), "float16"), p_lv5: T.handle, p_lv3: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv5 = T.match_buffer(p_lv5, (T.int64(1), n, T.int64(11008)), "float16")
        lv3 = T.match_buffer(p_lv3, (T.int64(1), n, T.int64(4096)), "float16")
        p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
        var_T_transpose_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
        for i, j in T.grid(T.int64(11008), T.int64(4096)):
            with T.block("decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(lv50[v_i // T.int64(5), v_j], lv51[v_i // T.int64(40), v_j])
                T.writes(decode[v_i, v_j])
                decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv50[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv51[v_i // T.int64(40), v_j]
        for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(decode[v_ax1, v_ax0])
                T.writes(var_T_transpose_intermediate[v_ax0, v_ax1])
                var_T_transpose_intermediate[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv5[v_i0, v_i1, v_k], var_T_transpose_intermediate[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv5[v_i0, v_i1, v_k] * var_T_transpose_intermediate[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv3[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int64):
        T.func_attr({"tir.noalias": T.bool(True)})
        var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), n, n), "float16")
        # with T.block("root"):
        var_make_diag_mask_te_intermediate = T.alloc_buffer((n, n), "float16")
        for i, j in T.grid(n, n):
            with T.block("make_diag_mask_te"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads()
                T.writes(var_make_diag_mask_te_intermediate[v_i, v_j])
                var_make_diag_mask_te_intermediate[v_i, v_j] = T.Select(v_i < v_j, T.float16(-65504), T.float16(65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), n, n):
            with T.block("T_broadcast_to"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_make_diag_mask_te_intermediate[v_ax2, v_ax3])
                T.writes(var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_make_diag_mask_te_intermediate[v_ax2, v_ax3]

    @T.prim_func
    def fused_softmax1_cast1(p_lv36: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n, m = T.int64(), T.int64()
        lv36 = T.match_buffer(p_lv36, (T.int64(1), T.int64(32), n, m))
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16")
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n))
        T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
        T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n))
        var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_maxelem"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv36[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
                with T.init():
                    T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38)
                T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv36[v_i0, v_i1, v_i2, v_k])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_exp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(lv36[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])
                T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
                T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv36[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_expsum"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
                with T.init():
                    T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)
                T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_norm"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])
                T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.block_attr({"axis": 3})
                var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])

385
+ @T.prim_func
386
+ def fused_softmax2_cast4(p_lv1613: T.handle, p_output0: T.handle):
387
+ T.func_attr({"tir.noalias": T.bool(True)})
388
+ n = T.int64()
389
+ lv1613 = T.match_buffer(p_lv1613, (T.int64(1), T.int64(32), T.int64(1), n))
390
+ var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
391
+ # with T.block("root"):
392
+ T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
393
+ T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
394
+ T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
395
+ var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
396
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
397
+ with T.block("T_softmax_maxelem"):
398
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
399
+ T.reads(lv1613[v_i0, v_i1, v_i2, v_k])
400
+ T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
401
+ with T.init():
402
+ T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38)
403
+ T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv1613[v_i0, v_i1, v_i2, v_k])
404
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
405
+ with T.block("T_softmax_exp"):
406
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
407
+ T.reads(lv1613[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])
408
+ T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
409
+ T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv1613[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])
410
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
411
+ with T.block("T_softmax_expsum"):
412
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
413
+ T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])
414
+ T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
415
+ with T.init():
416
+ T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)
417
+ T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]
418
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
419
+ with T.block("T_softmax_norm"):
420
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
421
+ T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])
422
+ T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
423
+ T.block_attr({"axis": 3})
424
+ var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]
425
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
426
+ with T.block("compute"):
427
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
428
+ T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
429
+ T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
430
+ var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
431
+
432
+ @T.prim_func
433
+ def matmul3(var_A: T.handle, var_B: T.handle, var_matmul: T.handle):
434
+ T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
435
+ n, m = T.int64(), T.int64()
436
+ A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), "float16")
437
+ B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(128)), "float16")
438
+ matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
439
+ # with T.block("root"):
440
+ for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), m):
441
+ with T.block("matmul"):
442
+ v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
443
+ T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
444
+ T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
445
+ with T.init():
446
+ matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
447
+ matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]
448
+
449
+ @T.prim_func
450
+ def matmul8(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):
451
+ T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
452
+ n = T.int64()
453
+ A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
454
+ B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
455
+ # with T.block("root"):
456
+ for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n):
457
+ with T.block("matmul"):
458
+ v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
459
+ T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
460
+ T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
461
+ with T.init():
462
+ matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
463
+ matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]
464
+
465
+ @T.prim_func
466
+ def reshape(var_A: T.handle, var_T_reshape: T.handle):
467
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
468
+ n = T.int64()
469
+ A = T.match_buffer(var_A, (T.int64(1), n), "int32")
470
+ T_reshape = T.match_buffer(var_T_reshape, (n,), "int32")
471
+ # with T.block("root"):
472
+ for ax0 in range(n):
473
+ with T.block("T_reshape"):
474
+ v_ax0 = T.axis.spatial(n, ax0)
475
+ T.reads(A[T.int64(0), v_ax0 % n])
476
+ T.writes(T_reshape[v_ax0])
477
+ T_reshape[v_ax0] = A[T.int64(0), v_ax0 % n]
478
+
479
+ @T.prim_func
480
+ def reshape1(var_A: T.handle, var_T_reshape: T.handle):
481
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
482
+ n = T.int64()
483
+ A = T.match_buffer(var_A, (n, T.int64(4096)), "float16")
484
+ T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16")
485
+ # with T.block("root"):
486
+ for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
487
+ with T.block("T_reshape"):
488
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
489
+ T.reads(A[(v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096)])
490
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
491
+ T_reshape[v_ax0, v_ax1, v_ax2] = A[(v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096)]
492
+
493
+ @T.prim_func
494
+ def reshape2(var_A: T.handle, var_T_reshape: T.handle):
495
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
496
+ n = T.int64()
497
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
498
+ T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
499
+ # with T.block("root"):
500
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
501
+ with T.block("T_reshape"):
502
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
503
+ T.reads(A[T.int64(0), ((v_ax2 * T.int64(128) + v_ax3) // T.int64(4096) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])
504
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
505
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), ((v_ax2 * T.int64(128) + v_ax3) // T.int64(4096) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]
506
+
507
+ @T.prim_func
508
+ def reshape3(var_A: T.handle, var_T_reshape: T.handle):
509
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
510
+ m = T.int64()
511
+ A = T.match_buffer(var_A, (m, T.int64(32), T.int64(128)), "float16")
512
+ T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), m, T.int64(32), T.int64(128)), "float16")
513
+ # with T.block("root"):
514
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), m, T.int64(32), T.int64(128)):
515
+ with T.block("T_reshape"):
516
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
517
+ T.reads(A[((v_ax3 // T.int64(128) + v_ax2) // T.int64(32) + v_ax0 * m + v_ax1) % m, (v_ax3 // T.int64(128) + v_ax2) % T.int64(32), v_ax3 % T.int64(128)])
518
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
519
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[((v_ax3 // T.int64(128) + v_ax2) // T.int64(32) + v_ax0 * m + v_ax1) % m, (v_ax3 // T.int64(128) + v_ax2) % T.int64(32), v_ax3 % T.int64(128)]
520
+
521
+ @T.prim_func
522
+ def reshape4(var_A: T.handle, var_T_reshape: T.handle):
523
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
524
+ n = T.int64()
525
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
526
+ T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16")
527
+ # with T.block("root"):
528
+ for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
529
+ with T.block("T_reshape"):
530
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
531
+ T.reads(A[T.int64(0), (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)])
532
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
533
+ T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)]
534
+
535
+ @T.prim_func
536
+ def rms_norm(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle):
537
+ T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
538
+ n = T.int64()
539
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
540
+ rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16")
541
+ # with T.block("root"):
542
+ Ared_temp = T.alloc_buffer((T.int64(1), n))
543
+ for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
544
+ with T.block("Ared_temp"):
545
+ v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
546
+ T.reads(A[v_bsz, v_i, v_k])
547
+ T.writes(Ared_temp[v_bsz, v_i])
548
+ with T.init():
549
+ Ared_temp[v_bsz, v_i] = T.float32(0)
550
+ Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k])
551
+ for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
552
+ with T.block("rms_norm"):
553
+ v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
554
+ T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
555
+ T.writes(rms_norm_1[v_bsz, v_i, v_k])
556
+ rms_norm_1[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / T.sqrt(Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))
557
+
558
+ @T.prim_func
559
+ def rotary_embedding(var_A: T.handle, B: T.Buffer((T.int64(2048), T.int64(128)), "float16"), C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), var_rotary: T.handle, m: T.int64):
560
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
561
+ n = T.int64()
562
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
563
+ rotary = T.match_buffer(var_rotary, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
564
+ # with T.block("root"):
565
+ for i0, i1, i2, i3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
566
+ with T.block("rotary"):
567
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
568
+ T.reads(B[m + v_i1 - n, v_i3], A[v_i0, v_i1, v_i2, v_i3 - T.int64(64):v_i3 - T.int64(64) + T.int64(129)], C[m + v_i1 - n, v_i3])
569
+ T.writes(rotary[v_i0, v_i1, v_i2, v_i3])
570
+ rotary[v_i0, v_i1, v_i2, v_i3] = B[m + v_i1 - n, v_i3] * A[v_i0, v_i1, v_i2, v_i3] + C[m + v_i1 - n, v_i3] * T.Select(T.int64(64) <= v_i3, A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)], A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1))
571
+
572
+ @T.prim_func
573
+ def slice(var_A: T.handle, slice_1: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
574
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
575
+ n = T.int64()
576
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
577
+ # with T.block("root"):
578
+ for i, j, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
579
+ with T.block("slice"):
580
+ v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k])
581
+ T.reads(A[v_i, n - T.int64(1), v_k])
582
+ T.writes(slice_1[v_i, v_j, v_k])
583
+ slice_1[v_i, v_j, v_k] = A[v_i, n - T.int64(1), v_k]
584
+
585
+ @T.prim_func
586
+ def squeeze(var_A: T.handle, var_T_squeeze: T.handle):
587
+ T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
588
+ n = T.int64()
589
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
590
+ T_squeeze = T.match_buffer(var_T_squeeze, (n, T.int64(32), T.int64(128)), "float16")
591
+ # with T.block("root"):
592
+ for ax0, ax1, ax2 in T.grid(n, T.int64(32), T.int64(128)):
593
+ with T.block("T_squeeze"):
594
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
595
+ T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2])
596
+ T.writes(T_squeeze[v_ax0, v_ax1, v_ax2])
597
+ T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2]
598
+
599
+ @T.prim_func
600
+ def take_decode(A: T.Buffer((T.int64(32000), T.int64(824)), "uint16"), B: T.Buffer((T.int64(32000), T.int64(103)), "float16"), var_C: T.handle, var_take_decode: T.handle):
601
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
602
+ n = T.int64()
603
+ C = T.match_buffer(var_C, (n,), "int32")
604
+ take_decode_1 = T.match_buffer(var_take_decode, (n, T.int64(4096)), "float16")
605
+ # with T.block("root"):
606
+ for i, j in T.grid(n, T.int64(4096)):
607
+ with T.block("take_decode"):
608
+ v_i, v_j = T.axis.remap("SS", [i, j])
609
+ T.reads(A[C[v_i], v_j // T.int64(5)], C[v_i], B[C[v_i], v_j // T.int64(40)])
610
+ T.writes(take_decode_1[v_i, v_j])
611
+ take_decode_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[C[v_i], v_j // T.int64(5)]), T.Cast("uint32", v_j % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[C[v_i], v_j // T.int64(40)]
612
+
613
+ @T.prim_func
614
+ def transpose3(var_A: T.handle, var_T_transpose: T.handle):
615
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
616
+ n = T.int64()
617
+ A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
618
+ T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
619
+ # with T.block("root"):
620
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, T.int64(128)):
621
+ with T.block("T_transpose"):
622
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
623
+ T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
624
+ T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
625
+ T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
626
+
627
+ @T.prim_func
628
+ def transpose4(var_A: T.handle, var_T_transpose: T.handle):
629
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
630
+ n = T.int64()
631
+ A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
632
+ T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
633
+ # with T.block("root"):
634
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
635
+ with T.block("T_transpose"):
636
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
637
+ T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
638
+ T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
639
+ T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
640
+ # fmt: on
debug/mod_tir_static.py ADDED
@@ -0,0 +1,364 @@
+
+ from tvm.script import ir as I
+ from tvm.script import tir as T
+
+ # fmt: off
+ # from tvm.script import ir as I
+ # from tvm.script import tir as T
+
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def divide(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), B: T.Buffer((), "float32"), T_divide: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):
+ T.func_attr({"op_pattern": 0, "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
+ with T.block("T_divide"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(A[v_ax0, v_ax1, v_ax2], B[()])
+ T.writes(T_divide[v_ax0, v_ax1, v_ax2])
+ T_divide[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2] / B[()]
+
+ @T.prim_func
+ def fused_decode4_fused_matmul4_cast2(lv2931: T.Buffer((T.int64(824), T.int64(32000)), "uint16"), lv2932: T.Buffer((T.int64(103), T.int64(32000)), "float16"), lv3152: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(32000)), "float16")
+ var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), "float16")
+ for i, j in T.grid(T.int64(4096), T.int64(32000)):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(lv2931[v_i // T.int64(5), v_j], lv2932[v_i // T.int64(40), v_j])
+ T.writes(var_decode_intermediate[v_i, v_j])
+ var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv2931[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv2932[v_i // T.int64(40), v_j]
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(32000), T.int64(4096)):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ T.reads(lv3152[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
+ T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+ with T.init():
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv3152[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
+ for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
+ with T.block("compute"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])
+ T.writes(p_output0_intermediate[v_i0, v_i1, v_i2])
+ p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2])
+
+ @T.prim_func
+ def fused_decode5_fused_matmul7_add1(lv1605: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv1606: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv197: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1581: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
+ var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")
+ for i, j in T.grid(T.int64(4096), T.int64(4096)):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(lv1605[v_i // T.int64(5), v_j], lv1606[v_i // T.int64(40), v_j])
+ T.writes(var_decode_intermediate[v_i, v_j])
+ var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1605[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1606[v_i // T.int64(40), v_j]
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ T.reads(lv197[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
+ T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+ with T.init():
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv197[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(lv1581[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+ T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+ p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv1581[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
+ @T.prim_func
+ def fused_decode5_matmul7(lv1587: T.Buffer((T.int64(824), T.int64(4096)), "uint16"), lv1588: T.Buffer((T.int64(103), T.int64(4096)), "float16"), lv1583: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
+ for i, j in T.grid(T.int64(4096), T.int64(4096)):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(lv1587[v_i // T.int64(5), v_j], lv1588[v_i // T.int64(40), v_j])
+ T.writes(var_decode_intermediate[v_i, v_j])
+ var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1587[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1588[v_i // T.int64(40), v_j]
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ T.reads(lv1583[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
+ T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+ with T.init():
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1583[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
+
+ @T.prim_func
+ def fused_decode6_fused_matmul9_multiply1(lv1617: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv1618: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv3: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
+ var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
+ for i, j in T.grid(T.int64(4096), T.int64(11008)):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(lv1617[v_i // T.int64(5), v_j], lv1618[v_i // T.int64(40), v_j])
+ T.writes(var_decode_intermediate[v_i, v_j])
+ var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1617[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1618[v_i // T.int64(40), v_j]
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
+ T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+ with T.init():
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(lv3[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+ T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+ p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] * var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
+ @T.prim_func
+ def fused_decode6_fused_matmul9_silu1(lv1611: T.Buffer((T.int64(824), T.int64(11008)), "uint16"), lv1612: T.Buffer((T.int64(103), T.int64(11008)), "float16"), lv1622: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ var_decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
+ var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
+ compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
+ for i, j in T.grid(T.int64(4096), T.int64(11008)):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(lv1611[v_i // T.int64(5), v_j], lv1612[v_i // T.int64(40), v_j])
+ T.writes(var_decode_intermediate[v_i, v_j])
+ var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1611[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1612[v_i // T.int64(40), v_j]
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(11008), T.int64(4096)):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ T.reads(lv1622[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
+ T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+ with T.init():
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1622[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
+ for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
+ with T.block("compute"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2])
+ T.writes(compute[v_i0, v_i1, v_i2])
+ compute[v_i0, v_i1, v_i2] = T.sigmoid(var_matmul_intermediate[v_i0, v_i1, v_i2])
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
+ T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+ p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]
+
+ @T.prim_func
+ def fused_decode7_fused_matmul10_add1(lv1623: T.Buffer((T.int64(2208), T.int64(4096)), "uint16"), lv1624: T.Buffer((T.int64(276), T.int64(4096)), "float16"), lv200: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv198: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ var_decode_intermediate = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
+ var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")
+ for i, j in T.grid(T.int64(11008), T.int64(4096)):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(lv1623[v_i // T.int64(5), v_j], lv1624[v_i // T.int64(40), v_j])
+ T.writes(var_decode_intermediate[v_i, v_j])
+ var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", lv1623[v_i // T.int64(5), v_j]), T.Cast("uint32", v_i % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * lv1624[v_i // T.int64(40), v_j]
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ T.reads(lv200[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2])
+ T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+ with T.init():
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+ var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv200[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2]
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(lv198[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+ T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+ p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv198[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
+ @T.prim_func
+ def fused_reshape7_squeeze1(lv1591: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_T_squeeze_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(128)), "float16")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ var_T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(lv1591[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])
+ T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+ var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv1591[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(128)):
+ with T.block("T_squeeze"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2])
+ T.writes(var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2])
+ var_T_squeeze_intermediate[v_ax0, v_ax1, v_ax2] = var_T_reshape_intermediate[T.int64(0), v_ax0, v_ax1, v_ax2]
+
+ @T.prim_func
+ def fused_transpose7_reshape8(lv1616: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), var_T_reshape_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ var_T_transpose_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
+ with T.block("T_transpose"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(lv1616[v_ax0, v_ax2, v_ax1, v_ax3])
+ T.writes(var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+ var_T_transpose_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv1616[v_ax0, v_ax2, v_ax1, v_ax3]
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)])
+ T.writes(var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2])
+ var_T_reshape_intermediate[v_ax0, v_ax1, v_ax2] = var_T_transpose_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)]
+
+ @T.prim_func
+ def reshape5(A: T.Buffer((T.int64(1), T.int64(1)), "int32"), T_reshape: T.Buffer((T.int64(1),), "int32")):
+ T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for ax0 in range(T.int64(1)):
+ with T.block("T_reshape"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ T.reads(A[T.int64(0), T.int64(0)])
+ T.writes(T_reshape[v_ax0])
+ T_reshape[v_ax0] = A[T.int64(0), T.int64(0)]
+
+ @T.prim_func
+ def reshape6(A: T.Buffer((T.int64(1), T.int64(4096)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(A[T.int64(0), v_ax2 % T.int64(4096)])
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+ T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax2 % T.int64(4096)]
241
+
242
+ @T.prim_func
243
+ def reshape7(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16")):
244
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
245
+ # with T.block("root"):
246
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
247
+ with T.block("T_reshape"):
248
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
249
+ T.reads(A[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])
250
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
251
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]
252
+
253
+ @T.prim_func
254
+ def rms_norm1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), B: T.Buffer((T.int64(4096),), "float16"), rms_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
255
+ T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
256
+ # with T.block("root"):
257
+ Ared_temp = T.alloc_buffer((T.int64(1), T.int64(1)))
258
+ for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
259
+ with T.block("Ared_temp"):
260
+ v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
261
+ T.reads(A[v_bsz, v_i, v_k])
262
+ T.writes(Ared_temp[v_bsz, v_i])
263
+ with T.init():
264
+ Ared_temp[v_bsz, v_i] = T.float32(0)
265
+ Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k])
266
+ for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
267
+ with T.block("rms_norm"):
268
+ v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
269
+ T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
270
+ T.writes(rms_norm[v_bsz, v_i, v_k])
271
+ rms_norm[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / T.sqrt(Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))
272
+
273
+ @T.prim_func
274
+ def rotary_embedding1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), B: T.Buffer((T.int64(2048), T.int64(128)), "float16"), C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), rotary: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), n: T.int64):
275
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
276
+ # with T.block("root"):
277
+ for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):
278
+ with T.block("rotary"):
279
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
280
+ T.reads(B[n + v_i1 - T.int64(1), v_i3], A[v_i0, v_i1, v_i2, v_i3 - T.int64(64):v_i3 - T.int64(64) + T.int64(129)], C[n + v_i1 - T.int64(1), v_i3])
281
+ T.writes(rotary[v_i0, v_i1, v_i2, v_i3])
282
+ rotary[v_i0, v_i1, v_i2, v_i3] = B[n + v_i1 - T.int64(1), v_i3] * A[v_i0, v_i1, v_i2, v_i3] + C[n + v_i1 - T.int64(1), v_i3] * T.Select(T.int64(64) <= v_i3, A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)], A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1))
283
+
284
+ @T.prim_func
285
+ def slice1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), slice: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
286
+ T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
287
+ # with T.block("root"):
288
+ for i, j, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
289
+ with T.block("slice"):
290
+ v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k])
291
+ T.reads(A[v_i, T.int64(0), v_k])
292
+ T.writes(slice[v_i, v_j, v_k])
293
+ slice[v_i, v_j, v_k] = A[v_i, T.int64(0), v_k]
294
+
295
+ @T.prim_func
296
+ def softmax(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):
297
+ T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
298
+ # with T.block("root"):
299
+ T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(1)))
300
+ T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)))
301
+ T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(1)))
302
+ for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
303
+ with T.block("T_softmax_maxelem"):
304
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
305
+ T.reads(A[v_i0, v_i1, v_k])
306
+ T.writes(T_softmax_maxelem[v_i0, v_i1])
307
+ with T.init():
308
+ T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e+38)
309
+ T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], A[v_i0, v_i1, v_k])
310
+ for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
311
+ with T.block("T_softmax_exp"):
312
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
313
+ T.reads(A[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1])
314
+ T.writes(T_softmax_exp[v_i0, v_i1, v_i2])
315
+ T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(A[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1])
316
+ for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
317
+ with T.block("T_softmax_expsum"):
318
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
319
+ T.reads(T_softmax_exp[v_i0, v_i1, v_k])
320
+ T.writes(T_softmax_expsum[v_i0, v_i1])
321
+ with T.init():
322
+ T_softmax_expsum[v_i0, v_i1] = T.float32(0)
323
+ T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k]
324
+ for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
325
+ with T.block("T_softmax_norm"):
326
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
327
+ T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1])
328
+ T.writes(T_softmax_norm[v_i0, v_i1, v_i2])
329
+ T.block_attr({"axis": 2})
330
+ T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1]
331
+
332
+ @T.prim_func
333
+ def squeeze1(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), T_squeeze: T.Buffer((T.int64(1), T.int64(32), T.int64(128)), "float16")):
334
+ T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
335
+ # with T.block("root"):
336
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(128)):
337
+ with T.block("T_squeeze"):
338
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
339
+ T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2])
340
+ T.writes(T_squeeze[v_ax0, v_ax1, v_ax2])
341
+ T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2]
342
+
343
+ @T.prim_func
344
+ def take_decode1(A: T.Buffer((T.int64(32000), T.int64(824)), "uint16"), B: T.Buffer((T.int64(32000), T.int64(103)), "float16"), C: T.Buffer((T.int64(1),), "int32"), take_decode: T.Buffer((T.int64(1), T.int64(4096)), "float16")):
345
+ T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
346
+ # with T.block("root"):
347
+ for i, j in T.grid(T.int64(1), T.int64(4096)):
348
+ with T.block("take_decode"):
349
+ v_i, v_j = T.axis.remap("SS", [i, j])
350
+ T.reads(A[C[v_i], v_j // T.int64(5)], C[v_i], B[C[v_i], v_j // T.int64(40)])
351
+ T.writes(take_decode[v_i, v_j])
352
+ take_decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[C[v_i], v_j // T.int64(5)]), T.Cast("uint32", v_j % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[C[v_i], v_j // T.int64(40)]
353
+
354
+ @T.prim_func
355
+ def transpose6(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):
356
+ T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
357
+ # with T.block("root"):
358
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128)):
359
+ with T.block("T_transpose"):
360
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
361
+ T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
362
+ T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
363
+ T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]
364
+ # fmt: on
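The take_decode1 kernel above is the quantized embedding lookup for this q3f16_0 build: each uint16 in A packs five 3-bit codes, and each float16 in B is a scale shared by a group of 40 decoded weights, with a fixed zero point of 3. A minimal NumPy sketch of the same arithmetic for one row (the function name and the variable width/group parameters are illustrative, not part of the compiled module):

```python
import numpy as np

def decode_q3(packed, scales, width, group=40):
    """Decode one row of 3-bit quantized weights, q3f16_0-style.

    packed: uint16 array; element j // 5 holds the code for weight j
    in bits [3 * (j % 5), 3 * (j % 5) + 3).
    scales: float16 array, one scale per `group` decoded weights.
    """
    out = np.empty(width, dtype=np.float16)
    for j in range(width):
        word = int(packed[j // 5])
        code = (word >> ((j % 5) * 3)) & 0x7  # unsigned 3-bit code in [0, 7]
        # Subtract the fixed zero point 3, then apply the per-group scale.
        out[j] = (np.float16(code) - np.float16(3)) * scales[j // group]
    return out

# Five codes 3, 4, 2, 7, 0 packed low-bits-first into one uint16, scale 2.0:
word = np.uint16(3 | (4 << 3) | (2 << 6) | (7 << 9) | (0 << 12))
row = decode_q3(np.array([word]), np.array([2.0], dtype=np.float16), width=5, group=5)
# row == [0.0, 2.0, -2.0, 8.0, -6.0]
```

The shard files that follow hold these packed uint16 words and float16 scales; the TIR kernel fuses the row gather (`C[v_i]`) with this decode so the full-precision embedding row is never materialized in memory.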
mod_cache_before_build_android.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:61f76a70ba5f4b9b97295c8d8497a2ce557191ef78236c1173d29ab4731775ac
+ size 33453240
params/mlc-chat-config.json ADDED
@@ -0,0 +1,15 @@
+ {
+ "model_lib": "vicuna-Wizard-7B-Uncensored-android-q3f16_0",
+ "local_id": "vicuna-Wizard-7B-Uncensored-android-q3f16_0",
+ "conv_template": "conv_one_shot",
+ "temperature": 0.7,
+ "repetition_penalty": 1.0,
+ "top_p": 0.95,
+ "mean_gen_len": 128,
+ "max_gen_len": 512,
+ "shift_fill_factor": 0.3,
+ "tokenizer_files": [
+ "tokenizer.json",
+ "tokenizer.model"
+ ]
+ }
params/ndarray-cache.json ADDED
The diff for this file is too large to render. See raw diff
params/params_shard_0.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:05e82cb7f41683ebca9c5dac10c6c6cb3114f64c1cb92174fa1ab2eb91c8bb58
+ size 52736000
params/params_shard_1.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f475ee094bbe8738202e5b90e34e6d738207ad9d60c80ccf10d779b507f27865
+ size 30955008
params/params_shard_10.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:90fde15cc917dba6523e863f9ba68a06ffa9a9aa6314aa958639bb35f8bd8652
+ size 18141184
params/params_shard_100.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:45ffa798effa1f4dcc0e18fed876e630d5158048deab79b0ce001cb35201dd1f
+ size 29578240
params/params_shard_101.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:304c07013a1d51ab9044b56ebc9796f4ac9b1710bc209c4cece32f68551e1624
+ size 18141184
params/params_shard_102.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:46c31d3d4ab8e92ded2da394d808016276379223f728bee7cfc8f0e61a0ad5c5
+ size 18141184
params/params_shard_103.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:69be2551dd7b84c72230f40566ff587fc0ec7211cca086ef8a0e231b7fca7d23
+ size 32643584
params/params_shard_104.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:89dae65b57366525d75ef3b4c3f9c0355e381501d41988590e9c26ba4af29932
+ size 30210560
params/params_shard_105.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca0dcea3d364a2ecbf5d0342013a1b05682d68fc941c069ed87af1447ca5d801
+ size 18141184
params/params_shard_106.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fee3ea256c0ae4d22230451a2dac4924975b715daad071f15734d4bf3bc88290
+ size 18141184
params/params_shard_107.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0dfd94eea61cc08152dd9ffd799a2fbc01eb9b528e7109e80b4746119fe252eb
+ size 18087936
params/params_shard_108.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:43caa8e093339b27fd7d269972a04f0e6a3528b40ce7dbcffa1845e0b0202ed4
+ size 29578240
params/params_shard_109.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4ea8b8e5a91f56582dcce624e269c5d93f9ca921949711967d3b884e50d3c13f
+ size 18141184
params/params_shard_11.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9bc9fa98e51b0473c5eec445fb8816069e0f36455e7026cfb970a9268e70e80b
+ size 18087936
params/params_shard_110.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:79ddd8fafc513a13bf2cccd87349a5af7b113c78c44ea8190a74b2c992cab689
+ size 18141184
params/params_shard_111.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e731157580eaa87e96639a3f2b754259030a395360bf359125923844559ac6c6
+ size 32643584
params/params_shard_112.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0e37507121ee529dcf48c793649622a63ed54b09f06b2fa41ff07405f61de6e2
+ size 30210560
params/params_shard_113.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b7220a590e7a69a483487510e9b40144965611a5f0900ba9485cc6f6efc07b29
+ size 18141184
params/params_shard_114.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b8edc3a8867cb78f134e6a8658fc50204b5532a601bd7fcdcae4ad7883104b0
+ size 18141184
params/params_shard_115.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b9c50d47d81180c3fc15f01cc7e2adcf3c37f8f6465eab91395d4fa7f0e8f33
+ size 18087936
params/params_shard_116.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a2f60e485c53239dc06cc00de12eed418233c4252229a9083add27b87c067147
+ size 29578240
params/params_shard_117.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:212d6a7e47c4bf4fcff8833ca24e0d9511802adee6b9cae168b4ff0055282cd3
+ size 18141184
params/params_shard_118.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:65aa7f26130e465fb9444786e399b04c84c220cb23f9f1cd855ca7a14464be88
+ size 18141184
params/params_shard_119.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f79cdedfcfd480c63c21017cbe1097fb6e6c1d2f8df7b0600efe466a5430cc29
+ size 32643584
params/params_shard_12.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:190f3293399e9520a5efe458f5a9288e03e2faf5adc538bf35d31d9b4673d5e1
+ size 29578240
params/params_shard_120.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d01cbd2eb461c329dfc89c9826bca789d4fb5b2b4afcb48f66c81f41d7408abe
+ size 30210560
params/params_shard_121.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:794717e19321f502f178bda5d8ea616fb48c19907c2f8747a8d50fd20b9d7f77
+ size 18141184
params/params_shard_122.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6bf6c6bc03172f2a58beea56dce92aa2ed933358e7a97e98bfe06e8b747cff8f
+ size 18141184
params/params_shard_123.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:76d3cf47f15deda8a3598fecab044153e11c4fe88a2c6bc58e65fa2f8cdddc7c
+ size 18087936
params/params_shard_124.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8db9fbc83e5a9fe7d4e294540efaefb325ba5573bdc0fb20ea6c715eaaf3f7a0
+ size 29578240
params/params_shard_125.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e8f2963ec1e598e6ee85a91c21dce132978347a4abc812b444b9f900ea3d45b6
+ size 18141184
params/params_shard_126.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b8fb390ac64f23f8782d141f6f72e11b5aa9b51bb211fb16f4b3a61bf3203784
+ size 18141184
params/params_shard_127.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:012dd36b4436148eaae04869543827c5e49c29257fe383ecdd7e8bfa257d00cf
+ size 32643584
params/params_shard_128.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8d3304f8d32e19db6f4b19c07893ad06558cddf74f19de8c9d02da2418b23829
+ size 52736000
params/params_shard_129.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:97a4dd35920543a930cef9887f429cc7a7cc73aff0ddd010e9a9407e23efd825
+ size 29208576
params/params_shard_13.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:83d54fb1a8389d74cf9068c2c363372649ac0b46c9c15f9632b8c17bd9f77c00
+ size 18141184
params/params_shard_14.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:60e75e756b9d83d6e5b522144a9f958d1c03ee858f8fad5e4c9e6312cd391763
+ size 18141184
params/params_shard_15.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7eb0e11ce7c95f3ce42192fc27a83674e8e5da6f328353f8e7a269a952c92502
+ size 32643584
params/params_shard_16.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b8375ce114ca4e3dbf33f39cb697f2cd5b9fbf62a05a48a12e0ac40e1d782929
+ size 30210560
params/params_shard_17.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:537be9cd347dac60a3cae2fde68c8b0fdc5c488508164aa86fc5f3fe7938f67e
+ size 18141184
params/params_shard_18.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1e8a3464fef817a62888684d12c56dfad343f79a68825c83278460c562a8f95c
+ size 18141184
params/params_shard_19.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e1d615fc53d422efa5d54033a79c34e0ca66b52bacc8b30d75163de2bd9821ec
+ size 18087936
params/params_shard_2.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1a5d64f7b0ff1fe6315b220755b175a2f4641e0e66591d7b25cf341d88b65afa
+ size 18141184
params/params_shard_20.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f977dce5c6b01dda9e9c6b7524ce35cf5adc485cab682c7340b2fa06a231aa77
+ size 29578240