from tvm.script import ir as I
from tvm.script import tir as T

# fmt: off

# from tvm.script import ir as I
# from tvm.script import tir as T

# IRModule made of TensorIR prim_funcs. Every function takes its output buffer
# as the last buffer argument and writes it in place. `n` and `m` are symbolic
# int64 extents, bound either through T.match_buffer shapes or as explicit
# scalar parameters.
# NOTE(review): the shapes (4096, 11008, 32 x 128) and function names suggest
# these are the kernels of a transformer language model with a quantized
# embedding table -- presumably a compiler dump; confirm against the producer.
@I.ir_module
class Module:
    # NT_matmul[0, i1, i2] = sum_k A[0, i1, k] * B[i2, k]
    # i.e. A multiplied by B with B indexed as (output, reduction); the
    # accumulation is carried out directly in the float16 output buffer.
    @T.prim_func
    def NT_matmul1(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), var_NT_matmul: T.handle):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
        NT_matmul = T.match_buffer(var_NT_matmul, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
            with T.block("NT_matmul"):
                # "SSSR": i0, i1, i2 are spatial axes, k is the reduction axis.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k])
                T.writes(NT_matmul[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul[v_i0, v_i1, v_i2] = T.float16(0)
                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k]

    # Widens the last axis of A from n to m columns: output columns
    # j < m - n are filled with float16 65504 (the largest finite float16),
    # the remaining columns copy A[..., j + n - m].
    @T.prim_func
    def extend_te(var_A: T.handle, var_concat_te: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), T.int64(1), n, n), "float16")
        m = T.int64()
        concat_te = T.match_buffer(var_concat_te, (T.int64(1), T.int64(1), n, m), "float16")
        # with T.block("root"):
        for b, _, i, j in T.grid(T.int64(1), T.int64(1), n, m):
            with T.block("concat_te"):
                v_b, v__, v_i, v_j = T.axis.remap("SSSS", [b, _, i, j])
                T.reads(A[v_b, v__, v_i, v_j + n - m])
                T.writes(concat_te[v_b, v__, v_i, v_j])
                concat_te[v_b, v__, v_i, v_j] = T.if_then_else(v_j < m - n, T.float16(65504), A[v_b, v__, v_i, v_j + n - m])

    # Fills a (1, 1, 1, n) float16 buffer with the constant 65504.
    @T.prim_func
    def full(var_T_full: T.handle):
        T.func_attr({"op_pattern": 0, "tir.noalias": T.bool(True)})
        n = T.int64()
        T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(1), T.int64(1), n), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), n):
            with T.block("T_full"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3])
                T_full[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(65504)

    # Two stages: (1) temp = lv41 x lv1386 with lv1386 indexed as
    # (output, reduction), 4096-long reduction; (2) output = lv2 + temp.
    @T.prim_func
    def fused_NT_matmul1_add1(p_lv41: T.handle, lv1386: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), p_lv2: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv41 = T.match_buffer(p_lv41, (T.int64(1), n, T.int64(4096)), "float16")
        lv2 = T.match_buffer(p_lv2, (T.int64(1), n, T.int64(4096)), "float16")
        var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv41[v_i0, v_i1, v_k], lv1386[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv41[v_i0, v_i1, v_k] * lv1386[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv2[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]

    # Five stages over a (1, 32, n, m) grid:
    #   1. NT_matmul : sum over the 128-long last axis of lv28[.., i2, k] * lv29[.., i3, k]
    #   2. T_divide  : scale by a float16 constant
    #   3. T_maximum : clamp from below at -65504
    #   4. T_minimum : elementwise min with lv5, whose axis 1 is always read at index 0
    #   5. compute   : cast the float16 result to float32 into the output buffer
    # NOTE(review): the first four stages look like scaled, masked attention
    # scores (32 heads, head size 128) -- confirm with the model definition.
    @T.prim_func
    def fused_NT_matmul2_divide2_maximum1_minimum1_cast3(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        m = T.int64()
        lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), m, T.int64(128)), "float16")
        lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float16")
        # No dtype is given for the output buffer; it receives float32 casts below.
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m))
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m), "float16")
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(128)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(lv28[v_i0, v_i1, v_i2, v_k], lv29[v_i0, v_i1, v_i3, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[v_i0, v_i1, v_i2, v_k] * lv29[v_i0, v_i1, v_i3, v_k]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                # The block is named T_divide but the statement is a multiplication
                # by a constant (presumably the reciprocal of the divisor -- TODO confirm).
                var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615)
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_maximum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_minimum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3])
                T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                # lv5 has extent 1 on axis 1, so the same slice is reused for all 32 values of v_ax1.
                var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])

    # Two stages: (1) temp = lv45 x lv1400 with lv1400 indexed as
    # (output, reduction): 11008 outputs, 4096-long reduction;
    # (2) output = lv50 * temp, elementwise.
    @T.prim_func
    def fused_NT_matmul3_multiply1(p_lv45: T.handle, lv1400: T.Buffer((T.int64(11008), T.int64(4096)), "float16"), p_lv50: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16")
        lv50 = T.match_buffer(p_lv50, (T.int64(1), n, T.int64(11008)), "float16")
        var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv45[v_i0, v_i1, v_k], lv1400[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * lv1400[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv50[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = lv50[v_ax0, v_ax1, v_ax2] * var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]

    # Three stages: (1) x = lv45 x lv1393 with lv1393 indexed as
    # (output, reduction); (2) compute = sigmoid(x); (3) output = x * compute.
    @T.prim_func
    def fused_NT_matmul3_silu1(p_lv45: T.handle, lv1393: T.Buffer((T.int64(11008), T.int64(4096)), "float16"), p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv45 = T.match_buffer(p_lv45, (T.int64(1), n, T.int64(4096)), "float16")
        var_T_multiply_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(11008)), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        compute = T.alloc_buffer((T.int64(1), n, T.int64(11008)), "float16")
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(11008), T.int64(4096)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv45[v_i0, v_i1, v_k], lv1393[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv45[v_i0, v_i1, v_k] * lv1393[v_i2, v_k]
        for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("compute"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute[v_i0, v_i1, v_i2])
                compute[v_i0, v_i1, v_i2] = T.sigmoid(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
                T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]

    # Two stages: (1) temp = lv51 x lv1407 with lv1407 indexed as
    # (output, reduction): 4096 outputs, 11008-long reduction;
    # (2) output = lv44 + temp.
    @T.prim_func
    def fused_NT_matmul4_add1(p_lv51: T.handle, lv1407: T.Buffer((T.int64(4096), T.int64(11008)), "float16"), p_lv44: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv51 = T.match_buffer(p_lv51, (T.int64(1), n, T.int64(11008)), "float16")
        lv44 = T.match_buffer(p_lv44, (T.int64(1), n, T.int64(4096)), "float16")
        var_T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(11008)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv51[v_i0, v_i1, v_k], lv1407[v_i2, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv51[v_i0, v_i1, v_k] * lv1407[v_i2, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv44[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
                var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = lv44[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]

    # Same five-stage pipeline as fused_NT_matmul2_divide2_maximum1_minimum1_cast3
    # (matmul -> scale -> max with -65504 -> min with mask -> cast to float32),
    # but axis 2 has the fixed extent 1 and the last axis has extent n.
    @T.prim_func
    def fused_NT_matmul_divide1_maximum_minimum_cast(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16"), p_lv1606: T.handle, p_lv1582: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv1606 = T.match_buffer(p_lv1606, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        lv1582 = T.match_buffer(p_lv1582, (T.int64(1), T.int64(1), T.int64(1), n), "float16")
        # No dtype is given for the output buffer; it receives float32 casts below.
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n))
        # with T.block("root"):
        var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        var_T_minimum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n), "float16")
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(lv1605[v_i0, v_i1, v_i2, v_k], lv1606[v_i0, v_i1, v_i3, v_k])
                T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1605[v_i0, v_i1, v_i2, v_k] * lv1606[v_i0, v_i1, v_i3, v_k]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                # Named T_divide, implemented as a multiplication by a constant.
                var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615)
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_maximum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_minimum"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3])
                T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                # lv1582 has extent 1 on axis 1, so it is read at index 0 for every v_ax1.
                var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1582[v_ax0, T.int64(0), v_ax2, v_ax3])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3])

    # Builds an (n, n) float16 mask holding -65504 where row < column and
    # 65504 elsewhere, then copies it into the (1, 1, n, n) output.
    # Here `n` is an explicit scalar parameter rather than a shape-bound variable.
    @T.prim_func
    def fused_min_max_triu_te_broadcast_to(p_output0: T.handle, n: T.int64):
        T.func_attr({"tir.noalias": T.bool(True)})
        var_T_broadcast_to_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), n, n), "float16")
        # with T.block("root"):
        var_make_diag_mask_te_intermediate = T.alloc_buffer((n, n), "float16")
        for i, j in T.grid(n, n):
            with T.block("make_diag_mask_te"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads()
                T.writes(var_make_diag_mask_te_intermediate[v_i, v_j])
                var_make_diag_mask_te_intermediate[v_i, v_j] = T.Select(v_i < v_j, T.float16(-65504), T.float16(65504))
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), n, n):
            with T.block("T_broadcast_to"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(var_make_diag_mask_te_intermediate[v_ax2, v_ax3])
                T.writes(var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                var_T_broadcast_to_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_make_diag_mask_te_intermediate[v_ax2, v_ax3]

    # Softmax over the last axis of a (1, 32, 1, n) input, followed by a cast
    # to float16. Stages: running max -> exp(x - max) -> sum of exps ->
    # exp / sum -> cast. The input and all temporaries are declared without a
    # dtype and are initialised with T.float32 constants.
    @T.prim_func
    def fused_softmax1_cast1(p_lv1613: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        lv1613 = T.match_buffer(p_lv1613, (T.int64(1), T.int64(32), T.int64(1), n))
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
        T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
        T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
        var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_softmax_maxelem"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv1613[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
                with T.init():
                    # Start from the most negative finite float32 so any input replaces it.
                    T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38)
                T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv1613[v_i0, v_i1, v_i2, v_k])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_softmax_exp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(lv1613[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])
                T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
                T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv1613[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_softmax_expsum"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
                with T.init():
                    T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)
                T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_softmax_norm"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])
                T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.block_attr({"axis": 3})
                var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])

    # Same softmax-then-cast pipeline as fused_softmax1_cast1, for a
    # (1, 32, n, m) input; the reduction runs over the last axis (extent m).
    @T.prim_func
    def fused_softmax2_cast4(p_lv36: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n, m = T.int64(), T.int64()
        lv36 = T.match_buffer(p_lv36, (T.int64(1), T.int64(32), n, m))
        var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16")
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n))
        T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
        T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n))
        var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_maxelem"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv36[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
                with T.init():
                    # Start from the most negative finite float32 so any input replaces it.
                    T_softmax_maxelem[v_i0, v_i1, v_i2] = T.float32(-3.4028234663852886e+38)
                T_softmax_maxelem[v_i0, v_i1, v_i2] = T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv36[v_i0, v_i1, v_i2, v_k])
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_exp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(lv36[v_i0, v_i1, v_i2, v_i3], T_softmax_maxelem[v_i0, v_i1, v_i2])
                T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
                T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv36[v_i0, v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_expsum"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])
                T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
                with T.init():
                    T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)
                T_softmax_expsum[v_i0, v_i1, v_i2] = T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("T_softmax_norm"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], T_softmax_expsum[v_i0, v_i1, v_i2])
                T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.block_attr({"axis": 3})
                var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])

    # Batched matmul without transposition:
    # matmul[0, h, i2, i3] = sum_k A[0, h, i2, k] * B[0, h, k, i3],
    # with A of shape (1, 32, n, m) and B of shape (1, 32, m, 128).
    @T.prim_func
    def matmul10(var_A: T.handle, var_B: T.handle, var_matmul: T.handle):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        n, m = T.int64(), T.int64()
        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, m), "float16")
        B = T.match_buffer(var_B, (T.int64(1), T.int64(32), m, T.int64(128)), "float16")
        matmul = T.match_buffer(var_matmul, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        # with T.block("root"):
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, T.int64(128), m):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
                T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]

    # Same contraction as matmul10 with axis 2 fixed to extent 1:
    # A is (1, 32, 1, n), B is (1, 32, n, 128), output is (1, 32, 1, 128).
    @T.prim_func
    def matmul5(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
        B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        # with T.block("root"):
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
                T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]

    # Reshape (n, 32, 128) -> (1, n, 32, 128).
    # The source index is written as un-simplified flattened-index arithmetic;
    # within the loop extents (v_ax0 == 0, v_ax1 < n, v_ax2 < 32, v_ax3 < 128)
    # it reduces to A[v_ax1, v_ax2, v_ax3].
    @T.prim_func
    def reshape3(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(32), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[((v_ax3 // T.int64(128) + v_ax2) // T.int64(32) + v_ax0 * n + v_ax1) % n, (v_ax3 // T.int64(128) + v_ax2) % T.int64(32), v_ax3 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[((v_ax3 // T.int64(128) + v_ax2) // T.int64(32) + v_ax0 * n + v_ax1) % n, (v_ax3 // T.int64(128) + v_ax2) % T.int64(32), v_ax3 % T.int64(128)]

    # Reshape (1, n) -> (n,) for int32 data; since v_ax0 < n the index
    # `v_ax0 % n` is just v_ax0.
    @T.prim_func
    def reshape5(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n), "int32")
        T_reshape = T.match_buffer(var_T_reshape, (n,), "int32")
        # with T.block("root"):
        for ax0 in range(n):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(n, ax0)
                T.reads(A[T.int64(0), v_ax0 % n])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = A[T.int64(0), v_ax0 % n]

    # Reshape (n, 4096) -> (1, n, 4096); within the loop extents the source
    # index reduces to A[v_ax1, v_ax2].
    @T.prim_func
    def reshape6(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(4096)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[(v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = A[(v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096)]

    # Reshape (1, n, 4096) -> (1, n, 32, 128): the last input axis is split so
    # that output element [.., v_ax2, v_ax3] comes from column v_ax2 * 128 + v_ax3
    # (which is always < 4096 within the loop extents).
    @T.prim_func
    def reshape7(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[T.int64(0), ((v_ax2 * T.int64(128) + v_ax3) // T.int64(4096) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.int64(0), ((v_ax2 * T.int64(128) + v_ax3) // T.int64(4096) + v_ax0 * n + v_ax1) % n, (v_ax2 * T.int64(128) + v_ax3) % T.int64(4096)]

    # Reshape (1, n, 32, 128) -> (1, n, 4096): the two trailing input axes are
    # merged; output column c reads A[0, row, c // 128, c % 128].
    @T.prim_func
    def reshape8(var_A: T.handle, var_T_reshape: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[T.int64(0), (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = A[T.int64(0), (v_ax2 // T.int64(4096) + v_ax0 * n + v_ax1) % n, v_ax2 % T.int64(4096) // T.int64(128), v_ax2 % T.int64(128)]

    # RMS normalisation over the last axis (extent 4096):
    #   Ared_temp[b, i] = sum_k A[b, i, k]^2            (accumulated in float32)
    #   out[b, i, k]    = B[k] * A[b, i, k] / sqrt(Ared_temp[b, i] * c + eps)
    # with c = 0.000244140625 (= 1/4096) and eps = 9.9999999999999995e-07.
    # The arithmetic is done in float32 and the result is cast to float16.
    @T.prim_func
    def rms_norm(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
        # The buffer is bound to `rms_norm_1`, not to the function's own name `rms_norm`.
        rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        Ared_temp = T.alloc_buffer((T.int64(1), n))
        for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("Ared_temp"):
                v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
                T.reads(A[v_bsz, v_i, v_k])
                T.writes(Ared_temp[v_bsz, v_i])
                with T.init():
                    Ared_temp[v_bsz, v_i] = T.float32(0)
                Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k])
        for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
            with T.block("rms_norm"):
                v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
                T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
                T.writes(rms_norm_1[v_bsz, v_i, v_k])
                rms_norm_1[v_bsz, v_i, v_k] = T.Cast("float16", T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / T.sqrt(Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))

    # rotary[.., i1, h, i3] = B[row, i3] * A[.., i3] + C[row, i3] * R, where
    #   row = m + i1 - n   (m is an explicit scalar parameter)
    #   R   = A[.., i3 - 64]   if i3 >= 64
    #         -A[.., i3 + 64]  otherwise
    # NOTE(review): B and C (2048 x 128 tables) are presumably cos/sin caches
    # of a rotary position embedding -- confirm with the caller.
    @T.prim_func
    def rotary_embedding(var_A: T.handle, B: T.Buffer((T.int64(2048), T.int64(128)), "float16"), C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), var_rotary: T.handle, m: T.int64):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        rotary = T.match_buffer(var_rotary, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for i0, i1, i2, i3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
            with T.block("rotary"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                # The declared read region of A spans offsets -64 .. +64 around v_i3,
                # covering all three accesses (v_i3 - 64, v_i3, v_i3 + 64).
                T.reads(B[m + v_i1 - n, v_i3], A[v_i0, v_i1, v_i2, v_i3 - T.int64(64):v_i3 - T.int64(64) + T.int64(129)], C[m + v_i1 - n, v_i3])
                T.writes(rotary[v_i0, v_i1, v_i2, v_i3])
                rotary[v_i0, v_i1, v_i2, v_i3] = B[m + v_i1 - n, v_i3] * A[v_i0, v_i1, v_i2, v_i3] + C[m + v_i1 - n, v_i3] * T.Select(T.int64(64) <= v_i3, A[v_i0, v_i1, v_i2, v_i3 - T.int64(64)], A[v_i0, v_i1, v_i2, v_i3 + T.int64(64)] * T.float16(-1))

    # Copies the last row along axis 1 of A (index n - 1) into the
    # (1, 1, 4096) output buffer.
    @T.prim_func
    def slice(var_A: T.handle, slice_1: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16")
        # with T.block("root"):
        for i, j, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
            with T.block("slice"):
                v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k])
                T.reads(A[v_i, n - T.int64(1), v_k])
                T.writes(slice_1[v_i, v_j, v_k])
                slice_1[v_i, v_j, v_k] = A[v_i, n - T.int64(1), v_k]

    # Drops the leading unit axis: (1, n, 32, 128) -> (n, 32, 128).
    @T.prim_func
    def squeeze1(var_A: T.handle, var_T_squeeze: T.handle):
        T.func_attr({"op_pattern": 1, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        T_squeeze = T.match_buffer(var_T_squeeze, (n, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(n, T.int64(32), T.int64(128)):
            with T.block("T_squeeze"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[T.int64(0), v_ax0, v_ax1, v_ax2])
                T.writes(T_squeeze[v_ax0, v_ax1, v_ax2])
                T_squeeze[v_ax0, v_ax1, v_ax2] = A[T.int64(0), v_ax0, v_ax1, v_ax2]

    # Gather-and-dequantize: for each output row i the source row is C[i].
    # Each uint16 word of A packs five 3-bit fields; element j uses word j // 5,
    # shifted right by (j % 5) * 3 bits and masked with 7. The field is cast to
    # float16, offset by -3 and multiplied by B[C[i], j // 40] (one B entry per
    # 40 output columns).
    @T.prim_func
    def take_decode1(A: T.Buffer((T.int64(32000), T.int64(824)), "uint16"), B: T.Buffer((T.int64(32000), T.int64(103)), "float16"), var_C: T.handle, var_take_decode: T.handle):
        T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)})
        n = T.int64()
        C = T.match_buffer(var_C, (n,), "int32")
        take_decode = T.match_buffer(var_take_decode, (n, T.int64(4096)), "float16")
        # with T.block("root"):
        for i, j in T.grid(n, T.int64(4096)):
            with T.block("take_decode"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(A[C[v_i], v_j // T.int64(5)], C[v_i], B[C[v_i], v_j // T.int64(40)])
                T.writes(take_decode[v_i, v_j])
                take_decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(T.Cast("uint32", A[C[v_i], v_j // T.int64(5)]), T.Cast("uint32", v_j % T.int64(5)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[C[v_i], v_j // T.int64(40)]

    # Swaps axes 1 and 2: (1, n, 32, 128) -> (1, 32, n, 128).
    @T.prim_func
    def transpose4(var_A: T.handle, var_T_transpose: T.handle):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]

    # Swaps axes 1 and 2 in the opposite direction of transpose4:
    # (1, 32, n, 128) -> (1, n, 32, 128).
    @T.prim_func
    def transpose7(var_A: T.handle, var_T_transpose: T.handle):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        T_transpose = T.match_buffer(var_T_transpose, (T.int64(1), n, T.int64(32), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), n, T.int64(32), T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax2, v_ax1, v_ax3]

# fmt: on