# from tvm.script import ir as I # from tvm.script import tir as T # from tvm.script import relax as R @I.ir_module class Module: I.module_attrs({"system_lib_prefix": "gpt_neox_q4f16_1_"}) @T.prim_func def apply_bitmask_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_bitmask: T.handle): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)}) batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True) logits = T.match_buffer(var_logits, (batch_size, vocab_size)) num_seq = T.int32(is_size_var=True) seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32") # with T.block("root"): for fused_s_v_0 in T.thread_binding((num_seq * vocab_size + 255) // 256, thread="blockIdx.x"): for fused_s_v_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("block"): vs = T.axis.spatial(num_seq, (fused_s_v_0 * 256 + fused_s_v_1) // vocab_size) vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 256 + fused_s_v_1) % vocab_size) T.where(fused_s_v_0 * 256 + fused_s_v_1 < num_seq * vocab_size) T.reads(bitmask[seq_ids[vs], vv // 32], seq_ids[vs], logits[seq_ids[vs], vv]) T.writes(logits[seq_ids[vs], vv]) logits[seq_ids[vs], vv] = T.if_then_else(T.bitwise_and(T.shift_right(bitmask[seq_ids[vs], vv // 32], vv % 32), 1) == 1, logits[seq_ids[vs], vv], T.float32(-340282346638528859811704183484516925440.0)) @T.prim_func def apply_logit_bias_inplace(var_logits: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_logit_bias: T.handle): T.func_attr({"op_pattern": 8, "target": 
T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)}) batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True) logits = T.match_buffer(var_logits, (batch_size, vocab_size)) num_token = T.int32(is_size_var=True) pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32") token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") logit_bias = T.match_buffer(var_logit_bias, (num_token,)) # with T.block("root"): for p0 in T.thread_binding((num_token + 255) // 256, thread="blockIdx.x"): for p1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("block"): vp = T.axis.spatial(num_token, p0 * 256 + p1) T.where(p0 * 256 + p1 < num_token) T.reads(logits[pos2seq_id[vp], token_ids[vp]], pos2seq_id[vp], token_ids[vp], logit_bias[vp]) T.writes(logits[pos2seq_id[vp], token_ids[vp]]) logits[pos2seq_id[vp], token_ids[vp]] = logits[pos2seq_id[vp], token_ids[vp]] + logit_bias[vp] @T.prim_func def apply_penalty_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_token_cnt: T.handle, var_penalties: T.handle): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)}) batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True) logits = 
T.match_buffer(var_logits, (batch_size, vocab_size)) num_seq = T.int32(is_size_var=True) seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") num_token = T.int32(is_size_var=True) pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32") token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32") penalties = T.match_buffer(var_penalties, (num_seq, 3)) # with T.block("root"): for p0 in T.thread_binding((num_token + 255) // 256, thread="blockIdx.x"): for p1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("block"): vp = T.axis.spatial(num_token, p0 * 256 + p1) T.where(p0 * 256 + p1 < num_token) T.reads(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]], seq_ids[pos2seq_id[vp]], pos2seq_id[vp], token_ids[vp], penalties[pos2seq_id[vp], 0:3], token_cnt[vp]) T.writes(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]]) logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] - (penalties[pos2seq_id[vp], 0] + T.Cast("float32", token_cnt[vp]) * penalties[pos2seq_id[vp], 1]) logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < T.float32(0.0), logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2], logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2]) @T.prim_func def batch_decode_paged_kv(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": 
"opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) B = T.int32(is_size_var=True) Q = T.match_buffer(Q_handle, (B, 8, 256), "float16") max_num_pages = T.int32(is_size_var=True) pages = T.match_buffer(pages_handle, (max_num_pages, 2, 8, 16, 256), "float16", offset_factor=1) page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1) nnz_pages = T.int32(is_size_var=True) page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1) length_info = T.match_buffer(var_length_info, (B,), "int32", offset_factor=1) k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1) q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1) output = T.match_buffer(output_handle, (B, 8, 256), "float16") lse = T.match_buffer(lse_handle, (B, 8)) # with T.block("root"): sm_scale: T.float32 = T.float32(0.090168440055560212) for bx in T.thread_binding(B, thread="blockIdx.x"): for fused_by_bz in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(1, thread="threadIdx.y"): for tx in T.thread_binding(64, thread="threadIdx.x"): for tz in T.thread_binding(4, thread="threadIdx.z"): with T.block("attn"): T.reads(page_table_indptr[bx:bx + 2], length_info[bx], q_rope_position[bx], Q[bx, fused_by_bz // 8 + ty + fused_by_bz % 8, tx * 4 - 128:tx * 4 - 128 + 260]) T.writes(output[bx, fused_by_bz % 8 + fused_by_bz // 8 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 8 + fused_by_bz // 8 + ty]) Q_local = T.alloc_buffer((4,), "float16", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") K_smem = T.alloc_buffer((4, 256), "float16", scope="shared") V_smem = T.alloc_buffer((4, 256), "float16", scope="shared") O_allreduce = T.alloc_buffer((4, 1, 256), scope="shared") 
md_allreduce = T.alloc_buffer((4, 1, 2), scope="shared") S_reduce_local = T.alloc_buffer((1,), scope="local") t0 = T.alloc_buffer((1,), scope="local") S_local = T.alloc_buffer((1,), scope="local") QK_local = T.alloc_buffer((4,), scope="local") V_local = T.alloc_buffer((4,), "float16", scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_prev = T.alloc_buffer((1,), scope="local") other_m = T.alloc_buffer((1,), scope="local") other_d = T.alloc_buffer((1,), scope="local") exp_mprev = T.alloc_buffer((1,), scope="local") exp_otherm = T.alloc_buffer((1,), scope="local") other_o = T.alloc_buffer((4,), scope="local") st_m = T.alloc_buffer((1,), scope="local") st_d = T.alloc_buffer((1,), scope="local") O_local = T.alloc_buffer((4,), scope="local") by: T.int32 = fused_by_bz % 8 bz: T.int32 = fused_by_bz // 8 batch_idx: T.int32 = bx cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx] cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1] kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[batch_idx], 0) st_m[0] = T.float32(-50000.0) st_d[0] = T.float32(1.0) for vec in T.vectorized(4): O_local[vec] = T.float32(0.0) for vec in T.vectorized(4): freq = T.float32() Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by + bz + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 128, Q[bx, by + bz + ty, tx * 4 + vec + 128] * T.float16(-1.0), Q[bx, by + bz + ty, tx * 4 + vec - 128]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 256) / T.float32(256.0))}), Q[bx, by + bz + ty, tx * 4 + vec]) for iterator in range((kv_chunk_len[0] + 3) // 4): tile_start_s: T.int32 = tz + ty tile_start_g: T.int32 = iterator * 4 + tz + ty for j in range(1): with T.block("KV_load"): T.reads() T.writes() row_g: 
T.int32 = tile_start_g + j if row_g < kv_chunk_len[0]: seq_offset: T.int32 = row_g page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16] page_offset: T.int32 = seq_offset % 16 for vec in T.vectorized(4): freq = T.float32() K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 128, pages[page_no, 0, by, page_offset, tx * 4 + vec + 128] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 128]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 256) / T.float32(256.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec]) V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec] else: for vec in T.vectorized(4): K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0) V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0) T.tvm_storage_sync("shared") m_prev[0] = st_m[0] for j in range(1): for vec in T.vectorized(4): QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale S_reduce_local[0] = T.float32(0.0) for vec in T.unroll(4): S_reduce_local[0] = S_reduce_local[0] + QK_local[vec] with T.block("block_cross_thread"): T.reads(S_reduce_local[0]) T.writes(t0[0]) T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx) S_local[j] = T.float32(-50000.0) if iterator * 4 + tz + j < kv_chunk_len[0]: S_local[j] = t0[0] st_m[0] = T.max(st_m[0], S_local[j]) o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0]) st_d[0] = st_d[0] * o_scale for j in range(1): S_local[j] = T.exp2(S_local[j] - st_m[0]) st_d[0] = st_d[0] + 
S_local[j] for j in T.vectorized(4): O_local[j] = O_local[j] * o_scale for j in range(1): for vec in T.vectorized(4): V_local[vec] = V_smem[tz + j, tx * 4 + vec] for vec in T.vectorized(4): O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j] for vec in T.vectorized(4): O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec] md_allreduce[tz, ty, 0] = st_m[0] md_allreduce[tz, ty, 1] = st_d[0] T.tvm_storage_sync("shared") st_m[0] = T.float32(-50000.0) st_d[0] = T.float32(1.0) for vec in T.vectorized(4): O_local[vec] = T.float32(0.0) for j in range(4): m_prev[0] = st_m[0] d_prev[0] = st_d[0] other_m[0] = md_allreduce[j, ty, 0] other_d[0] = md_allreduce[j, ty, 1] for vec in T.vectorized(4): other_o[vec] = O_allreduce[j, ty, tx * 4 + vec] st_m[0] = T.max(st_m[0], other_m[0]) st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0]) exp_mprev[0] = T.exp2(m_prev[0] - st_m[0]) exp_otherm[0] = T.exp2(other_m[0] - st_m[0]) for vec in T.vectorized(4): O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0] for vec in T.vectorized(4): O_local[vec] = O_local[vec] / st_d[0] for vec in T.vectorized(4): output[batch_idx, by + bz + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec]) lse[batch_idx, by + bz + ty] = st_m[0] + T.log2(st_d[0]) @T.prim_func def batch_decode_paged_kv_sliding_window(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, 
"max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) B = T.int32(is_size_var=True) Q = T.match_buffer(Q_handle, (B, 8, 256), "float16") max_num_pages = T.int32(is_size_var=True) pages = T.match_buffer(pages_handle, (max_num_pages, 2, 8, 16, 256), "float16", offset_factor=1) page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1) nnz_pages = T.int32(is_size_var=True) page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1) length_info = T.match_buffer(var_length_info, (3, B), "int32", offset_factor=1) k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1) q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1) output = T.match_buffer(output_handle, (B, 8, 256), "float16") lse = T.match_buffer(lse_handle, (B, 8)) # with T.block("root"): sm_scale: T.float32 = T.float32(0.090168440055560212) for bx in T.thread_binding(B, thread="blockIdx.x"): for fused_by_bz in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(1, thread="threadIdx.y"): for tx in T.thread_binding(64, thread="threadIdx.x"): for tz in T.thread_binding(4, thread="threadIdx.z"): with T.block("attn"): T.reads(page_table_indptr[bx:bx + 2], length_info[0:3, bx], q_rope_position[bx], Q[bx, fused_by_bz // 8 + ty + fused_by_bz % 8, tx * 4 - 128:tx * 4 - 128 + 260]) T.writes(output[bx, fused_by_bz % 8 + fused_by_bz // 8 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 8 + fused_by_bz // 8 + ty]) Q_local = T.alloc_buffer((4,), "float16", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") K_smem = T.alloc_buffer((4, 256), "float16", scope="shared") V_smem = T.alloc_buffer((4, 256), "float16", scope="shared") O_allreduce = T.alloc_buffer((4, 1, 256), scope="shared") md_allreduce = T.alloc_buffer((4, 1, 2), 
scope="shared") S_reduce_local = T.alloc_buffer((1,), scope="local") t0 = T.alloc_buffer((1,), scope="local") S_local = T.alloc_buffer((1,), scope="local") QK_local = T.alloc_buffer((4,), scope="local") V_local = T.alloc_buffer((4,), "float16", scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_prev = T.alloc_buffer((1,), scope="local") other_m = T.alloc_buffer((1,), scope="local") other_d = T.alloc_buffer((1,), scope="local") exp_mprev = T.alloc_buffer((1,), scope="local") exp_otherm = T.alloc_buffer((1,), scope="local") other_o = T.alloc_buffer((4,), scope="local") st_m = T.alloc_buffer((1,), scope="local") st_d = T.alloc_buffer((1,), scope="local") O_local = T.alloc_buffer((4,), scope="local") by: T.int32 = fused_by_bz % 8 bz: T.int32 = fused_by_bz // 8 batch_idx: T.int32 = bx cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx] cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1] kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, batch_idx] - length_info[1, batch_idx] + length_info[2, batch_idx], 0) st_m[0] = T.float32(-50000.0) st_d[0] = T.float32(1.0) for vec in T.vectorized(4): O_local[vec] = T.float32(0.0) for vec in T.vectorized(4): freq = T.float32() Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by + bz + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 128, Q[bx, by + bz + ty, tx * 4 + vec + 128] * T.float16(-1.0), Q[bx, by + bz + ty, tx * 4 + vec - 128]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 256) / T.float32(256.0))}), Q[bx, by + bz + ty, tx * 4 + vec]) for iterator in range((kv_chunk_len[0] + 3) // 4): tile_start_s: T.int32 = tz + ty tile_start_g: T.int32 = iterator * 4 + tz + ty for j in range(1): with T.block("KV_load"): T.reads() 
T.writes() row_g: T.int32 = tile_start_g + j if row_g < kv_chunk_len[0]: seq_offset: T.int32 = T.if_then_else(row_g < length_info[2, batch_idx], row_g, row_g - length_info[2, batch_idx] + length_info[1, batch_idx]) page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16] page_offset: T.int32 = seq_offset % 16 for vec in T.vectorized(4): freq = T.float32() K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 128, pages[page_no, 0, by, page_offset, tx * 4 + vec + 128] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 128]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 256) / T.float32(256.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec]) V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec] else: for vec in T.vectorized(4): K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0) V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0) T.tvm_storage_sync("shared") m_prev[0] = st_m[0] for j in range(1): for vec in T.vectorized(4): QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale S_reduce_local[0] = T.float32(0.0) for vec in T.unroll(4): S_reduce_local[0] = S_reduce_local[0] + QK_local[vec] with T.block("block_cross_thread"): T.reads(S_reduce_local[0]) T.writes(t0[0]) T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx) S_local[j] = T.float32(-50000.0) if iterator * 4 + tz + j < kv_chunk_len[0]: S_local[j] = t0[0] st_m[0] = T.max(st_m[0], S_local[j]) o_scale: T.float32 = 
T.exp2(m_prev[0] - st_m[0]) st_d[0] = st_d[0] * o_scale for j in range(1): S_local[j] = T.exp2(S_local[j] - st_m[0]) st_d[0] = st_d[0] + S_local[j] for j in T.vectorized(4): O_local[j] = O_local[j] * o_scale for j in range(1): for vec in T.vectorized(4): V_local[vec] = V_smem[tz + j, tx * 4 + vec] for vec in T.vectorized(4): O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j] for vec in T.vectorized(4): O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec] md_allreduce[tz, ty, 0] = st_m[0] md_allreduce[tz, ty, 1] = st_d[0] T.tvm_storage_sync("shared") st_m[0] = T.float32(-50000.0) st_d[0] = T.float32(1.0) for vec in T.vectorized(4): O_local[vec] = T.float32(0.0) for j in range(4): m_prev[0] = st_m[0] d_prev[0] = st_d[0] other_m[0] = md_allreduce[j, ty, 0] other_d[0] = md_allreduce[j, ty, 1] for vec in T.vectorized(4): other_o[vec] = O_allreduce[j, ty, tx * 4 + vec] st_m[0] = T.max(st_m[0], other_m[0]) st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0]) exp_mprev[0] = T.exp2(m_prev[0] - st_m[0]) exp_otherm[0] = T.exp2(other_m[0] - st_m[0]) for vec in T.vectorized(4): O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0] for vec in T.vectorized(4): O_local[vec] = O_local[vec] / st_d[0] for vec in T.vectorized(4): output[batch_idx, by + bz + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec]) lse[batch_idx, by + bz + ty] = st_m[0] + T.log2(st_d[0]) @T.prim_func def batch_prefill_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": 
"aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) total_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (total_len, 8, 256), "float16") batch_size = T.int32(is_size_var=True) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) max_num_pages = T.int32(is_size_var=True) pages = T.match_buffer(var_pages, (max_num_pages, 2, 8, 16, 256), "float16", offset_factor=1) page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1) nnz_pages = T.int32(is_size_var=True) page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1) length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1) k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1) q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1) output = T.match_buffer(var_output, (total_len, 8, 256), "float16") lse = T.match_buffer(var_lse, (total_len, 8)) # with T.block("root"): for lbx in T.thread_binding(16, thread="blockIdx.x"): for lby in T.thread_binding(8, thread="blockIdx.y"): for lty in T.thread_binding(4, thread="threadIdx.y"): for ltx in T.thread_binding(32, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() tile_id = T.alloc_buffer((1,), "int32", scope="local") batch_idx = T.alloc_buffer((1,), "int32", scope="local") batch_tiles = T.alloc_buffer((1,), "int32", scope="local") batch_rows = T.alloc_buffer((1,), "int32", scope="local") iterator = T.alloc_buffer((1,), "int32", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") Q_smem = T.alloc_buffer((16, 256), "float16", 
scope="shared") K_smem = T.alloc_buffer((16, 256), "float16", scope="shared") V_smem = T.alloc_buffer((16, 256), "float16", scope="shared") S_smem = T.alloc_buffer((16, 16), scope="shared") S_local = T.alloc_buffer((16, 16), scope="local") O_local = T.alloc_buffer((16, 256), scope="local") m_smem = T.alloc_buffer((16,), scope="shared") m_prev_smem = T.alloc_buffer((16,), scope="shared") d_smem = T.alloc_buffer((16,), scope="shared") m_new = T.alloc_buffer((1,), scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_new = T.alloc_buffer((1,), scope="local") tile_id[0] = bx batch_idx[0] = 0 batch_rows[0] = q_indptr[1] - q_indptr[0] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 while T.tvm_thread_invariant(batch_idx[0] < batch_size): while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: tile_id[0] = tile_id[0] - batch_tiles[0] batch_idx[0] = batch_idx[0] + 1 if batch_idx[0] < batch_size: b_idx: T.int32 = batch_idx[0] batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] LH_start: T.int32 = tile_id[0] * 16 q_indptr_val: T.int32 = q_indptr[b_idx] cur_page_indptr_begin: T.int32 = page_indptr[b_idx] cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0) T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: m_smem[row] = T.float32(-50000.0) d_smem[row] = T.float32(1.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + 
lj_1) T.reads() T.writes(O_local[i, j]) O_local[i, j] = T.float32(0.0) T.tvm_storage_sync("shared") for li_lj_fused_0 in range(8): for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for li_lj_fused_3 in T.vectorized(4): with T.block("Q_load"): i = T.axis.spatial(16, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 256) j = T.axis.spatial(256, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = q_indptr_val + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: freq = T.float32() Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, q[cur_L, cur_H_qo, j + 128] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 128]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), q[cur_L, cur_H_qo, j]) else: Q_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") for iterator_1 in range((kv_chunk_len[0] + 15) // 16): L_kv_start: T.int32 = iterator_1 * 16 for lz_ly_fused_0 in range(8): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("K_load"): i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 256) j = T.axis.spatial(256, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32 = cur_L page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16] page_offset: T.int32 = seq_offset % 16 freq = T.float32() K_smem[i, j] = T.if_then_else(rotary_mode 
== 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, pages[page_no, 0, by, page_offset, j + 128] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 128]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), pages[page_no, 0, by, page_offset, j]) else: K_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") for lz_ly_fused_0 in range(8): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("V_load"): i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 256) j = T.axis.spatial(256, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32 = cur_L page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16] page_offset: T.int32 = seq_offset % 16 V_smem[i, j] = pages[page_no, 1, by, page_offset, j] else: V_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") with T.block(""): T.reads(Q_smem[0:16, 0:256], K_smem[0:16, 0:256]) T.writes(S_local[0:16, 0:16]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(1, 2): with T.block("S_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 + li_1_init) j = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 2 + lj_1_init) T.reads() T.writes(S_local[i, j]) S_local[i, j] = T.float32(0.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, 
thread="threadIdx.x"): for lk_0, li_1, lj_1, lk_1 in T.grid(32, 1, 2, 8): with T.block("S_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) k = T.axis.reduce(256, lk_0 * 8 + lk_1) T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k]) T.writes(S_local[i, j]) S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.090168440055560212) T.tvm_storage_sync("shared") for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(1, 2): with T.block("S_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) T.reads(S_local[i, j]) T.writes(S_smem[i, j]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update1"): T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:16], d_smem[row], m_prev[i]) T.writes(m_prev[i], m_new[i], d_new[i]) m_prev[i] = m_smem[row] m_new[i] = m_smem[row] row_: T.int32 = LH_start + row for j in range(16): if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx with T.block("update"): T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:16], m_new[i]) T.writes(S_smem[row, 0:16]) for j in range(16): if row < 16: row_: T.int32 = LH_start + row if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - 
q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update"): T.reads(d_new[i], S_smem[row, 0:16], m_new[i], m_prev[i]) T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) for j in range(16): d_new[i] = d_new[i] + S_smem[row, j] m_smem[row] = m_new[i] d_smem[row] = d_new[i] m_prev_smem[row] = m_prev[i] T.tvm_storage_sync("shared") with T.block(""): T.reads(m_prev_smem[0:16], m_smem[0:16], S_smem[0:16, 0:16], V_smem[0:16, 0:256]) T.writes(O_local[0:16, 0:256]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(4, 8): with T.block("O_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 32 * 4 + li_1_init) j = T.axis.spatial(256, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 32 * 8 + lj_1_init) T.reads() T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, lk_1, li_1, lj_1 in T.grid(2, 8, 4, 8): with T.block("O_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) k = T.axis.reduce(16, lk_0 * 8 + lk_1) T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j]) T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with 
T.block("O_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) for li_0 in range(1): for li_1 in T.thread_binding(4, thread="threadIdx.y"): for li_2 in T.thread_binding(32, thread="threadIdx.x"): with T.block("lse_store"): i = T.axis.spatial(16, li_0 * 128 + li_1 * 32 + li_2) T.where((li_0 * 4 + li_1) * 32 + li_2 < 16) T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) tile_id[0] = tile_id[0] + 16
# ---------------------------------------------------------------------------
# batch_prefill_paged_kv_sliding_window: prefill attention for query rows
# q (total_len, 8, 256) against K/V gathered from the paged cache
# pages (max_num_pages, 2, 8, 16, 256). pages[page, 0, ...] is read as K and
# pages[page, 1, ...] as V; each page holds 16 positions (seq_offset // 16
# selects the entry of page_values, seq_offset % 16 the slot in the page).
# Sequence b owns query rows q_indptr[b]:q_indptr[b + 1] and the page ids
# page_values[page_indptr[b]:page_indptr[b + 1]].
# Launch: blockIdx.y (8) indexes dim 1 of q/output/lse and the matching K/V
# index of pages; each blockIdx.x (16) processes 16-row query tiles
# tile_id = bx, bx + 16, bx + 32, ... counted across the whole batch.
# Softmax is computed online in base 2: scores are multiplied by
# attn_score_scaling_factor * 0.090168440055560212 (== log2(e) / sqrt(256)),
# masked scores contribute exp2(-50000.0 - m), O is divided by the running
# sum d when stored, and lse = m + log2(d).
# KV length of sequence b: (num_pages - 1) * 16 + length_info[0, b]
# - length_info[1, b] + length_info[2, b], or 0 if it owns no pages.
# KV index cur_L is read from slot cur_L when cur_L < length_info[2, b],
# otherwise from slot cur_L - length_info[2, b] + length_info[1, b].
# NOTE(review): this looks like "kept prefix (attention sink) + sliding
# window"; the exact meaning of the three length_info rows is not visible
# here -- confirm with the KV-cache code that fills length_info.
# causal > 0 enables the causal mask; rotary_mode == 1 applies rotary
# embedding while loading Q (position q_rope_position[row]) and K (position
# k_rope_pos_offset[b] + cur_L). The _0 parameter is not used in the body.
@T.prim_func def batch_prefill_paged_kv_sliding_window(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) total_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (total_len, 8, 256), "float16") batch_size = 
T.int32(is_size_var=True) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) max_num_pages = T.int32(is_size_var=True) pages = T.match_buffer(var_pages, (max_num_pages, 2, 8, 16, 256), "float16", offset_factor=1) page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1) nnz_pages = T.int32(is_size_var=True) page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1) length_info = T.match_buffer(var_length_info, (3, batch_size), "int32", offset_factor=1) k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1) q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1) output = T.match_buffer(var_output, (total_len, 8, 256), "float16") lse = T.match_buffer(var_lse, (total_len, 8)) # with T.block("root"): for lbx in T.thread_binding(16, thread="blockIdx.x"): for lby in T.thread_binding(8, thread="blockIdx.y"): for lty in T.thread_binding(4, thread="threadIdx.y"): for ltx in T.thread_binding(32, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() tile_id = T.alloc_buffer((1,), "int32", scope="local") batch_idx = T.alloc_buffer((1,), "int32", scope="local") batch_tiles = T.alloc_buffer((1,), "int32", scope="local") batch_rows = T.alloc_buffer((1,), "int32", scope="local") iterator = T.alloc_buffer((1,), "int32", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") Q_smem = T.alloc_buffer((16, 256), "float16", scope="shared") K_smem = T.alloc_buffer((16, 256), "float16", scope="shared") V_smem = T.alloc_buffer((16, 256), "float16", scope="shared") S_smem = T.alloc_buffer((16, 16), scope="shared") S_local = T.alloc_buffer((16, 16), scope="local") O_local = T.alloc_buffer((16, 256), scope="local") m_smem = T.alloc_buffer((16,), scope="shared") m_prev_smem = T.alloc_buffer((16,), scope="shared") d_smem = 
T.alloc_buffer((16,), scope="shared") m_new = T.alloc_buffer((1,), scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_new = T.alloc_buffer((1,), scope="local") tile_id[0] = bx batch_idx[0] = 0 batch_rows[0] = q_indptr[1] - q_indptr[0] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16
# Tile scan: subtract each sequence's tile count ceil(rows / 16) from tile_id
# until tile_id < batch_tiles, i.e. until it falls inside sequence batch_idx;
# the outer loop ends once batch_idx runs past batch_size.
while T.tvm_thread_invariant(batch_idx[0] < batch_size): while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: tile_id[0] = tile_id[0] - batch_tiles[0] batch_idx[0] = batch_idx[0] + 1 if batch_idx[0] < batch_size: b_idx: T.int32 = batch_idx[0] batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] LH_start: T.int32 = tile_id[0] * 16 q_indptr_val: T.int32 = q_indptr[b_idx] cur_page_indptr_begin: T.int32 = page_indptr[b_idx] cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, b_idx] - length_info[1, b_idx] + length_info[2, b_idx], 0) T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: m_smem[row] = T.float32(-50000.0) d_smem[row] = T.float32(1.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads() T.writes(O_local[i, j]) O_local[i, j] = T.float32(0.0) T.tvm_storage_sync("shared") for li_lj_fused_0 in range(8): for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for li_lj_fused_3 in T.vectorized(4): with T.block("Q_load"): i = T.axis.spatial(16, 
(li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 256) j = T.axis.spatial(256, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = q_indptr_val + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: freq = T.float32() Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, q[cur_L, cur_H_qo, j + 128] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 128]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), q[cur_L, cur_H_qo, j]) else: Q_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared")
# Stream the sequence's KV in chunks of 16 positions (L_kv_start = 16 * chunk):
# load K (with rotary embedding when rotary_mode == 1) and V from the pages;
# positions at or past kv_chunk_len are zero-filled in K_smem / V_smem.
for iterator_1 in range((kv_chunk_len[0] + 15) // 16): L_kv_start: T.int32 = iterator_1 * 16 for lz_ly_fused_0 in range(8): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("K_load"): i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 256) j = T.axis.spatial(256, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx]) page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16] page_offset: T.int32 = seq_offset % 16 freq = T.float32() K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, pages[page_no, 0, by, page_offset, j + 128] * T.float16(-1.0), pages[page_no, 0, by, 
page_offset, j - 128]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), pages[page_no, 0, by, page_offset, j]) else: K_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") for lz_ly_fused_0 in range(8): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("V_load"): i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 256) j = T.axis.spatial(256, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx]) page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16] page_offset: T.int32 = seq_offset % 16 V_smem[i, j] = pages[page_no, 1, by, page_offset, j] else: V_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") with T.block(""): T.reads(Q_smem[0:16, 0:256], K_smem[0:16, 0:256]) T.writes(S_local[0:16, 0:16]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(1, 2): with T.block("S_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 + li_1_init) j = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 2 + lj_1_init) T.reads() T.writes(S_local[i, j]) S_local[i, j] = T.float32(0.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, li_1, lj_1, lk_1 in T.grid(32, 1, 2, 8): with T.block("S_gemm_update"): i = T.axis.spatial(16, 
(li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) k = T.axis.reduce(256, lk_0 * 8 + lk_1) T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k]) T.writes(S_local[i, j]) S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.090168440055560212) T.tvm_storage_sync("shared") for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(1, 2): with T.block("S_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) T.reads(S_local[i, j]) T.writes(S_smem[i, j]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared")
# Online-softmax step for each of the 16 query rows (row_ = LH_start + row).
# Column j is kept iff L_kv_start + j < kv_chunk_len - qo_len + row_ + 1 when
# causal > 0 (qo_len = q_indptr[b + 1] - q_indptr[b]), else iff
# L_kv_start + j < kv_chunk_len. m becomes the max over kept columns, S becomes
# exp2(S - m) (masked: exp2(-50000.0 - m)), d is rescaled by exp2(m_prev - m)
# and accumulates the row sum; the O gemm then rescales O_local by
# exp2(m_prev - m) before adding S @ V.
for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update1"): T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:16], d_smem[row], m_prev[i]) T.writes(m_prev[i], m_new[i], d_new[i]) m_prev[i] = m_smem[row] m_new[i] = m_smem[row] row_: T.int32 = LH_start + row for j in range(16): if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx with T.block("update"): T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:16], m_new[i]) T.writes(S_smem[row, 0:16]) for j in range(16): if row < 16: row_: T.int32 = LH_start + row if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: S_smem[row, j] = 
T.exp2(T.float32(-50000.0) - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update"): T.reads(d_new[i], S_smem[row, 0:16], m_new[i], m_prev[i]) T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) for j in range(16): d_new[i] = d_new[i] + S_smem[row, j] m_smem[row] = m_new[i] d_smem[row] = d_new[i] m_prev_smem[row] = m_prev[i] T.tvm_storage_sync("shared") with T.block(""): T.reads(m_prev_smem[0:16], m_smem[0:16], S_smem[0:16, 0:16], V_smem[0:16, 0:256]) T.writes(O_local[0:16, 0:256]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(4, 8): with T.block("O_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 32 * 4 + li_1_init) j = T.axis.spatial(256, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 32 * 8 + lj_1_init) T.reads() T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, lk_1, li_1, lj_1 in T.grid(2, 8, 4, 8): with T.block("O_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) k = T.axis.reduce(16, lk_0 * 8 + lk_1) T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j]) T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
# All KV chunks consumed: write O / d (cast to float16) and lse = m + log2(d)
# for tile rows inside this sequence (cur_L < q_indptr[b_idx + 1]), then
# advance tile_id by the blockIdx.x extent (16) to this block's next tile.
for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, 
(li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) for li_0 in range(1): for li_1 in T.thread_binding(4, thread="threadIdx.y"): for li_2 in T.thread_binding(32, thread="threadIdx.x"): with T.block("lse_store"): i = T.axis.spatial(16, li_0 * 128 + li_1 * 32 + li_2) T.where((li_0 * 4 + li_1) * 32 + li_2 < 16) T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) tile_id[0] = tile_id[0] + 16
# ---------------------------------------------------------------------------
# batch_prefill_ragged_kv: prefill attention for q (qo_len, 8, 256) against
# contiguous ragged K/V tensors k, v (kv_len, 8, 256). Sequence b owns query
# rows q_indptr[b]:q_indptr[b + 1] and KV rows kv_indptr[b]:kv_indptr[b + 1].
# Launch: blockIdx.y (8) indexes dim 1 of q/k/v/output/lse; each blockIdx.x (8)
# processes 16-row query tiles tile_id = bx, bx + 8, bx + 16, ... counted
# across the whole batch, streaming KV in chunks of 16 positions.
# K is staged transposed in shared memory (K_smem is (256, 16) and written as
# K_smem[j, i]), so the score gemm reads K_smem[k_1, j].
# Softmax is computed online in base 2 (scores multiplied by
# attn_score_scaling_factor * 0.090168440055560212 == log2(e) / sqrt(256));
# output = O / d cast to float16 and lse = m + log2(d).
# causal > 0 enables the causal mask; rotary_mode == 1 applies rotary
# embedding while loading Q (position q_rope_position[row]) and K (position
# k_rope_pos_offset[b] + KV index within the sequence).
@T.prim_func def batch_prefill_ragged_kv(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_k_rope_pos_offset: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) qo_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (qo_len, 8, 256), "float16") batch_size = T.int32(is_size_var=True) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) kv_len = T.int32(is_size_var=True) k = T.match_buffer(var_k, (kv_len, 8, 256), "float16") v = 
T.match_buffer(var_v, (kv_len, 8, 256), "float16") kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1) q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1) k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1) output = T.match_buffer(var_output, (qo_len, 8, 256), "float16") lse = T.match_buffer(var_lse, (qo_len, 8)) # with T.block("root"): for lbx in T.thread_binding(8, thread="blockIdx.x"): for lby in T.thread_binding(8, thread="blockIdx.y"): for lty in T.thread_binding(4, thread="threadIdx.y"): for ltx in T.thread_binding(32, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() tile_id = T.alloc_buffer((1,), "int32", scope="local") batch_idx = T.alloc_buffer((1,), "int32", scope="local") batch_tiles = T.alloc_buffer((1,), "int32", scope="local") batch_rows = T.alloc_buffer((1,), "int32", scope="local") iterator = T.alloc_buffer((1,), "int32", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") Q_smem = T.alloc_buffer((16, 256), "float16", scope="shared") K_smem = T.alloc_buffer((256, 16), "float16", scope="shared") V_smem = T.alloc_buffer((16, 256), "float16", scope="shared") S_smem = T.alloc_buffer((16, 16), scope="shared") S_local = T.alloc_buffer((16, 16), scope="local") O_local = T.alloc_buffer((16, 256), scope="local") m_smem = T.alloc_buffer((16,), scope="shared") m_prev_smem = T.alloc_buffer((16,), scope="shared") d_smem = T.alloc_buffer((16,), scope="shared") m_new = T.alloc_buffer((1,), scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_new = T.alloc_buffer((1,), scope="local") tile_id[0] = bx batch_idx[0] = 0 batch_rows[0] = q_indptr[1] - q_indptr[0] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 while T.tvm_thread_invariant(batch_idx[0] < batch_size): while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: tile_id[0] = 
tile_id[0] - batch_tiles[0] batch_idx[0] = batch_idx[0] + 1 if batch_idx[0] < batch_size: b_idx: T.int32 = batch_idx[0] batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] q_indptr_val: T.int32 = q_indptr[b_idx] LH_start: T.int32 = tile_id[0] * 16 kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: m_smem[row] = T.float32(-50000.0) d_smem[row] = T.float32(1.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1 in range(4): for lj_1_0 in T.unroll(1): for lj_1_1 in T.vectorized(8): with T.block("O_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1_0 * 8 + lj_1_1) T.reads() T.writes(O_local[i, j]) O_local[i, j] = T.float32(0.0) T.tvm_storage_sync("shared") for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_0, lj_0_0 in T.grid(2, 2): for lj_1 in T.vectorized(8): with T.block("Q_load"): i = T.axis.spatial(16, li_0 * 8 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 16) j = T.axis.spatial(256, lj_0_0 * 128 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 16 * 8 + lj_1) T.reads() T.writes() cur_L: T.int32 = q_indptr_val + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: freq = T.float32() Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, q[cur_L, cur_H_qo, j + 128] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 128]))), where={freq: T.Cast("float32", 
q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), q[cur_L, cur_H_qo, j]) else: Q_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared")
# Stream this sequence's KV rows (starting at L_kv_base = kv_indptr[b_idx]) in
# chunks of 16 positions; K gets rotary embedding when rotary_mode == 1 and is
# stored transposed; positions at or past kv_chunk_len are zero-filled.
for iterator_1 in range((kv_chunk_len[0] + 15) // 16): L_kv_start: T.int32 = iterator_1 * 16 L_kv_base: T.int32 = kv_indptr[b_idx] for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lz_0, ly_0_0 in T.grid(2, 2): for ly_1 in T.vectorized(8): with T.block("K_load"): i = T.axis.spatial(16, lz_0 * 8 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 16) j = T.axis.spatial(256, ly_0_0 * 128 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 16 * 8 + ly_1) T.reads(kv_chunk_len[0], k_rope_pos_offset[b_idx], k[L_kv_base + L_kv_start + i, by, j - 128:j - 128 + 257]) T.writes(K_smem[j, i]) cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: freq = T.float32() K_smem[j, i] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[L_kv_base + cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, k[L_kv_base + cur_L, by, j + 128] * T.float16(-1.0), k[L_kv_base + cur_L, by, j - 128]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), k[L_kv_base + cur_L, by, j]) else: K_smem[j, i] = T.float16(0.0) T.tvm_storage_sync("shared") for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lz_0, ly_0_0 in T.grid(2, 2): for ly_1 in T.vectorized(8): with T.block("V_load"): i = T.axis.spatial(16, lz_0 * 8 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 16) j = T.axis.spatial(256, ly_0_0 * 128 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 16 * 8 + ly_1) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: 
V_smem[i, j] = v[L_kv_base + cur_L, by, j] else: V_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") with T.block(""): T.reads(Q_smem[0:16, 0:256], K_smem[0:256, 0:16]) T.writes(S_local[0:16, 0:16]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init in T.unroll(1): for lj_1_0_init in T.unroll(1): for lj_1_1_init in T.vectorized(2): with T.block("S_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 + li_1_init) j = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 2 + lj_1_0_init * 2 + lj_1_1_init) T.reads() T.writes(S_local[i, j]) S_local[i, j] = T.float32(0.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0 in range(16): for li_1 in T.unroll(1): for lj_1_0 in T.unroll(1): for lj_1_1 in T.vectorized(2): for lk_1 in range(16): with T.block("S_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1_0 * 2 + lj_1_1) k_1 = T.axis.reduce(256, lk_0 * 16 + lk_1) T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[k_1, j]) T.writes(S_local[i, j]) S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[k_1, j]) * attn_score_scaling_factor * T.float32(0.090168440055560212) T.tvm_storage_sync("shared") for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1 in range(1): for lj_1_0 in T.unroll(1): for lj_1_1 in T.vectorized(2): with T.block("S_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1_0 * 2 + lj_1_1) T.reads(S_local[i, j]) 
T.writes(S_smem[i, j]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared")
# Online-softmax step for each of the 16 query rows (row_ = LH_start + row).
# Column j is kept iff L_kv_start + j < kv_chunk_len - qo_len + row_ + 1 when
# causal > 0 (qo_len = q_indptr[b + 1] - q_indptr[b]), else iff
# L_kv_start + j < kv_chunk_len. Updates the running max m, replaces S with
# exp2(S - m) (masked: exp2(-50000.0 - m)), and rescales/accumulates d; the O
# gemm then rescales O_local by exp2(m_prev - m) before adding S @ V.
for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update1"): T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:16], d_smem[row], m_prev[i]) T.writes(m_prev[i], m_new[i], d_new[i]) m_prev[i] = m_smem[row] m_new[i] = m_smem[row] row_: T.int32 = LH_start + row for j in range(16): if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx with T.block("update"): T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:16], m_new[i]) T.writes(S_smem[row, 0:16]) for j in range(16): if row < 16: row_: T.int32 = LH_start + row if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update"): T.reads(d_new[i], S_smem[row, 0:16], m_new[i], m_prev[i]) T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) for j in range(16): d_new[i] = d_new[i] + S_smem[row, j] m_smem[row] = m_new[i] d_smem[row] = d_new[i] m_prev_smem[row] = m_prev[i] T.tvm_storage_sync("shared") with T.block(""): T.reads(m_prev_smem[0:16], m_smem[0:16], S_smem[0:16, 0:16], V_smem[0:16, 0:256]) T.writes(O_local[0:16, 0:256]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init in T.unroll(4): for lj_1_0_init in T.unroll(1): for lj_1_1_init in T.vectorized(8): with T.block("O_gemm_init"): i = 
T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 32 * 4 + li_1_init) j = T.axis.spatial(256, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 32 * 8 + lj_1_0_init * 8 + lj_1_1_init) T.reads() T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, lk_1 in T.grid(1, 16): for li_1 in T.unroll(4): for lj_1_0 in T.unroll(1): for lj_1_1 in T.vectorized(8): with T.block("O_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1_0 * 8 + lj_1_1) k_1 = T.axis.reduce(16, lk_0 * 16 + lk_1) T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j]) T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
# All KV chunks consumed: write O / d (cast to float16) and lse = m + log2(d)
# for tile rows inside this sequence (cur_L < q_indptr[b_idx + 1]), then
# advance tile_id by the blockIdx.x extent (8) to this block's next tile.
for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1 in range(4): for lj_1_0 in T.unroll(1): for lj_1_1 in T.vectorized(8): with T.block("O_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1_0 * 8 + lj_1_1) T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) for li_0 in range(1): for li_1 in T.thread_binding(4, thread="threadIdx.y"): for li_2 in T.thread_binding(32, thread="threadIdx.x"): with T.block("lse_store"): i = T.axis.spatial(16, li_0 * 128 + li_1 * 32 + li_2) T.where((li_0 * 4 + li_1) * 32 + li_2 
< 16) T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) tile_id[0] = tile_id[0] + 8 @T.prim_func def batch_tree_attn(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_mn_indptr: T.handle, var_mask: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, batch_size: T.int32): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) qo_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (qo_len, 8, 256), "float16") q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) kv_len = T.int32(is_size_var=True) k = T.match_buffer(var_k, (kv_len, 8, 256), "float16") v = T.match_buffer(var_v, (kv_len, 8, 256), "float16") kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1) q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1) mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", offset_factor=1) tree_size = T.int32(is_size_var=True) mask = T.match_buffer(var_mask, (tree_size, 2), "int32", offset_factor=1) output = T.match_buffer(var_output, (qo_len, 8, 256), "float16") lse = T.match_buffer(var_lse, (qo_len, 8)) # with T.block("root"): for lbx in T.thread_binding(16, thread="blockIdx.x"): for lby in T.thread_binding(8,
thread="blockIdx.y"): for lty in T.thread_binding(4, thread="threadIdx.y"): for ltx in T.thread_binding(32, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() tile_id = T.alloc_buffer((1,), "int32", scope="local") batch_idx = T.alloc_buffer((1,), "int32", scope="local") batch_tiles = T.alloc_buffer((1,), "int32", scope="local") batch_rows = T.alloc_buffer((1,), "int32", scope="local") iterator = T.alloc_buffer((1,), "int32", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") Q_smem = T.alloc_buffer((16, 256), "float16", scope="shared") K_smem = T.alloc_buffer((32, 256), "float16", scope="shared") V_smem = T.alloc_buffer((32, 256), "float16", scope="shared") S_smem = T.alloc_buffer((16, 32), scope="shared") S_local = T.alloc_buffer((16, 32), scope="local") O_local = T.alloc_buffer((16, 256), scope="local") m_smem = T.alloc_buffer((16,), scope="shared") m_prev_smem = T.alloc_buffer((16,), scope="shared") d_smem = T.alloc_buffer((16,), scope="shared") m_new = T.alloc_buffer((1,), scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_new = T.alloc_buffer((1,), scope="local") tile_id[0] = bx batch_idx[0] = 0 batch_rows[0] = q_indptr[1] - q_indptr[0] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 while T.tvm_thread_invariant(batch_idx[0] < batch_size): while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: tile_id[0] = tile_id[0] - batch_tiles[0] batch_idx[0] = batch_idx[0] + 1 if batch_idx[0] < batch_size: b_idx: T.int32 = batch_idx[0] batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] LH_start: T.int32 = tile_id[0] * 16 q_indptr_val: T.int32 = q_indptr[b_idx] kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: 
m_smem[row] = T.float32(-50000.0) d_smem[row] = T.float32(1.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads() T.writes(O_local[i, j]) O_local[i, j] = T.float32(0.0) T.tvm_storage_sync("shared") for li_lj_fused_0 in range(8): for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for li_lj_fused_3 in T.vectorized(4): with T.block("Q_load"): i = T.axis.spatial(16, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 256) j = T.axis.spatial(256, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = q_indptr_val + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: freq = T.float32() Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, q[cur_L, cur_H_qo, j + 128] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 128]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), q[cur_L, cur_H_qo, j]) else: Q_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") for iterator_1 in range((kv_chunk_len[0] + 31) // 32): L_kv_start: T.int32 = iterator_1 * 32 L_kv_base: T.int32 = kv_indptr[b_idx] for lz_ly_fused_0 in range(16): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("KV_load"): i = T.axis.spatial(32, (lz_ly_fused_0 * 
512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 256) j = T.axis.spatial(256, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = L_kv_base + L_kv_start + i if L_kv_start + i < kv_chunk_len[0]: freq = T.float32() K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, k[cur_L, by, j + 128] * T.float16(-1.0), k[cur_L, by, j - 128]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), k[cur_L, by, j]) V_smem[i, j] = v[cur_L, by, j] else: K_smem[i, j] = T.float16(0.0) V_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") with T.block(""): T.reads(Q_smem[0:16, 0:256], K_smem[0:32, 0:256]) T.writes(S_local[0:16, 0:32]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(2, 2): with T.block("S_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 2 + li_1_init) j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 2 + lj_1_init) T.reads() T.writes(S_local[i, j]) S_local[i, j] = T.float32(0.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, li_1, lj_1, lk_1 in T.grid(32, 2, 2, 8): with T.block("S_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 2 + li_1) j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 2 + lj_1) k_1 = T.axis.reduce(256, lk_0 * 8 + lk_1) T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[j, k_1]) T.writes(S_local[i, j]) S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * 
T.Cast("float32", K_smem[j, k_1]) * attn_score_scaling_factor * T.float32(0.090168440055560212) T.tvm_storage_sync("shared") for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(2, 2): with T.block("S_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 2 + li_1) j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 2 + lj_1) T.reads(S_local[i, j]) T.writes(S_smem[i, j]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update1"): T.reads(m_smem[row], kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i]) T.writes(m_prev[i], m_new[i], d_new[i]) m_prev[i] = m_smem[row] m_new[i] = m_smem[row] row_: T.int32 = LH_start + row for j in range(32): if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - 
(q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx with T.block("update"): T.reads(kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i]) T.writes(S_smem[row, 0:32]) for j in range(32): if row < 16: row_: T.int32 = LH_start + row if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update"): T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i]) T.writes(d_new[i], 
m_smem[row], d_smem[row], m_prev_smem[row]) for j in range(32): d_new[i] = d_new[i] + S_smem[row, j] m_smem[row] = m_new[i] d_smem[row] = d_new[i] m_prev_smem[row] = m_prev[i] T.tvm_storage_sync("shared") with T.block(""): T.reads(m_prev_smem[0:16], m_smem[0:16], S_smem[0:16, 0:32], V_smem[0:32, 0:256]) T.writes(O_local[0:16, 0:256]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(4, 8): with T.block("O_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 32 * 4 + li_1_init) j = T.axis.spatial(256, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 32 * 8 + lj_1_init) T.reads() T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8): with T.block("O_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) k_1 = T.axis.reduce(32, lk_0 * 8 + lk_1) T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j]) T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) cur_L: T.int32 = q_indptr[b_idx] + 
(LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) for li_0 in range(1): for li_1 in T.thread_binding(4, thread="threadIdx.y"): for li_2 in T.thread_binding(32, thread="threadIdx.x"): with T.block("lse_store"): i = T.axis.spatial(16, li_0 * 128 + li_1 * 32 + li_2) T.where((li_0 * 4 + li_1) * 32 + li_2 < 16) T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) tile_id[0] = tile_id[0] + 16 @T.prim_func def chunk_lse(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle): T.func_attr({"op_pattern": 4, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True) A = T.match_buffer(var_A, (batch_size, vocab_size)) temperature = T.match_buffer(var_temperature, (batch_size,)) num_chunks = T.int64(is_size_var=True) chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks)) chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks)) # with T.block("root"): A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(4096))) temp_max = T.alloc_buffer((batch_size, num_chunks)) temp_sum = T.alloc_buffer((batch_size, num_chunks)) for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)): with T.block("pad"): v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2]) 
T.writes(A_pad[v0, v1, v2]) A_pad[v0, v1, v2] = T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-340282346638528859811704183484516925440.0)) for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)): with T.block("max"): v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) T.reads(A_pad[v0, v1, v2]) T.writes(temp_max[v0, v1]) with T.init(): temp_max[v0, v1] = T.float32(-340282346638528859811704183484516925440.0) temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2]) for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)): with T.block("sum_exp"): v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) T.reads(temperature[v0], A_pad[v0, v1, v2], temp_max[v0, v1]) T.writes(temp_sum[v0, v1]) with T.init(): temp_sum[v0, v1] = T.float32(0.0) temp_sum[v0, v1] = temp_sum[v0, v1] + T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]), T.Cast("float32", A_pad[v0, v1, v2] == temp_max[v0, v1])), T.float32(0.0)) for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)): with T.block("log"): v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) T.reads(temperature[v0], temp_sum[v0, v1], temp_max[v0, v1]) T.writes(chunked_sum[v0, v1], chunked_max[v0, v1]) chunked_sum[v0, v1] = T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.log(temp_sum[v0, v1]), temp_sum[v0, v1]) chunked_max[v0, v1] = temp_max[v0, v1] @T.prim_func def compact_kv_copy(var_pages: T.handle, var_copy_length_indptr: T.handle, var_copy_src_dst_pos: T.handle, batch_size: T.int32): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, 
"max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) num_pages = T.int32() pages = T.match_buffer(var_pages, (num_pages, 2, 8, 16, 256), "float16", offset_factor=1) copy_length_indptr = T.match_buffer(var_copy_length_indptr, (batch_size + 1,), "int32", offset_factor=1) total_copy_length = T.int32() copy_src_dst_pos = T.match_buffer(var_copy_src_dst_pos, (2, total_copy_length), "int32", offset_factor=1) with T.block("root"): T.reads() T.writes() for bhd_o in T.thread_binding(batch_size * 8, thread="blockIdx.x"): for bhd_i in T.thread_binding(256, thread="threadIdx.x"): b: T.int32 = (bhd_o * 256 + bhd_i) // 2048 h: T.int32 = (bhd_o * 256 + bhd_i) // 256 % 8 d: T.int32 = (bhd_o * 256 + bhd_i) % 256 if bhd_o * 256 + bhd_i < batch_size * 8 * 256: for i in range(copy_length_indptr[b + 1] - copy_length_indptr[b]): src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i] dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i] pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[src_pos // 16, 0, h, src_pos % 16, d] pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[src_pos // 16, 1, h, src_pos % 16, d] @T.prim_func def copy_single_page(var_pages: T.handle, src_page_id: T.int64, tgt_page_id: T.int64, copy_length: T.int64): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) num_pages, page_size = T.int32(), T.int64() pages = T.match_buffer(var_pages, (num_pages, 2, 8, page_size, 256), "float16", offset_factor=1) # with T.block("root"): for b in T.thread_binding(copy_length * T.int64(8), 
thread="blockIdx.x"): for t in T.thread_binding(256, thread="threadIdx.x"): with T.block("copy"): vh = T.axis.spatial(8, T.Cast("int32", (b * T.int64(256) + T.Cast("int64", t)) // (copy_length * T.int64(256)))) vp = T.axis.spatial(copy_length, (b * T.int64(256) + T.Cast("int64", t)) % (copy_length * T.int64(256)) // T.int64(256)) vd = T.axis.spatial(256, T.Cast("int32", (b * T.int64(256) + T.Cast("int64", t)) % T.int64(256))) T.where(b * T.int64(256) + T.Cast("int64", t) < copy_length * T.int64(8) * T.int64(256)) T.reads(pages[src_page_id, 0:2, vh, vp, vd]) T.writes(pages[tgt_page_id, 0:2, vh, vp, vd]) pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd] pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd] @T.prim_func(private=True) def dequantize(var_gpt_neox_embed_in_q_weight: T.handle, var_gpt_neox_embed_in_q_scale: T.handle, var_dequantize: T.handle): T.func_attr({"op_pattern": 2, "target": T.target({"keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) vocab_size = T.int64() gpt_neox_embed_in_q_weight = T.match_buffer(var_gpt_neox_embed_in_q_weight, (vocab_size, T.int64(256)), "uint32") gpt_neox_embed_in_q_scale = T.match_buffer(var_gpt_neox_embed_in_q_scale, (vocab_size, T.int64(64)), "float16") dequantize = T.match_buffer(var_dequantize, (vocab_size, T.int64(2048)), "float16") # with T.block("root"): compute = T.alloc_buffer((vocab_size, T.int64(2048)), "float16") for i0, i1 in T.grid(vocab_size, T.int64(2048)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(gpt_neox_embed_in_q_weight[v_i0, v_i1 // T.int64(8)]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(gpt_neox_embed_in_q_weight[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) 
* T.int64(4))), T.uint32(15))) for i0, i1 in T.grid(vocab_size, T.int64(2048)): with T.block("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(compute[v_i0, v_i1], gpt_neox_embed_in_q_scale[v_i0, v_i1 // T.int64(32)]) T.writes(dequantize[v_i0, v_i1]) dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * gpt_neox_embed_in_q_scale[v_i0, v_i1 // T.int64(32)] @T.prim_func(private=True) def dequantize1(gpt_neox_layers_0_attention_query_key_value_q_weight1: T.Buffer((T.int64(6144), T.int64(256)), "uint32"), gpt_neox_layers_0_attention_query_key_value_q_scale1: T.Buffer((T.int64(6144), T.int64(64)), "float16"), dequantize: T.Buffer((T.int64(6144), T.int64(2048)), "float16")): T.func_attr({"op_pattern": 2, "target": T.target({"keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) # with T.block("root"): compute = T.alloc_buffer((T.int64(6144), T.int64(2048)), "float16") for i0, i1 in T.grid(T.int64(6144), T.int64(2048)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(gpt_neox_layers_0_attention_query_key_value_q_weight1[v_i0, v_i1 // T.int64(8)]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(gpt_neox_layers_0_attention_query_key_value_q_weight1[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15))) for i0, i1 in T.grid(T.int64(6144), T.int64(2048)): with T.block("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(compute[v_i0, v_i1], gpt_neox_layers_0_attention_query_key_value_q_scale1[v_i0, v_i1 // T.int64(32)]) T.writes(dequantize[v_i0, v_i1]) dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * gpt_neox_layers_0_attention_query_key_value_q_scale1[v_i0, v_i1 // T.int64(32)] @T.prim_func(private=True) 
def dequantize2(gpt_neox_layers_0_attention_dense_q_weight1: T.Buffer((T.int64(2048), T.int64(256)), "uint32"), gpt_neox_layers_0_attention_dense_q_scale1: T.Buffer((T.int64(2048), T.int64(64)), "float16"), dequantize: T.Buffer((T.int64(2048), T.int64(2048)), "float16")): T.func_attr({"op_pattern": 2, "target": T.target({"keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) # with T.block("root"): compute = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16") for i0, i1 in T.grid(T.int64(2048), T.int64(2048)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(gpt_neox_layers_0_attention_dense_q_weight1[v_i0, v_i1 // T.int64(8)]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(gpt_neox_layers_0_attention_dense_q_weight1[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15))) for i0, i1 in T.grid(T.int64(2048), T.int64(2048)): with T.block("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(compute[v_i0, v_i1], gpt_neox_layers_0_attention_dense_q_scale1[v_i0, v_i1 // T.int64(32)]) T.writes(dequantize[v_i0, v_i1]) dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * gpt_neox_layers_0_attention_dense_q_scale1[v_i0, v_i1 // T.int64(32)] @T.prim_func(private=True) def dequantize3(gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight1: T.Buffer((T.int64(8192), T.int64(256)), "uint32"), gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale1: T.Buffer((T.int64(8192), T.int64(64)), "float16"), dequantize: T.Buffer((T.int64(8192), T.int64(2048)), "float16")): T.func_attr({"op_pattern": 2, "target": T.target({"keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, 
"max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) # with T.block("root"): compute = T.alloc_buffer((T.int64(8192), T.int64(2048)), "float16") for i0, i1 in T.grid(T.int64(8192), T.int64(2048)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight1[v_i0, v_i1 // T.int64(8)]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight1[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15))) for i0, i1 in T.grid(T.int64(8192), T.int64(2048)): with T.block("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(compute[v_i0, v_i1], gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale1[v_i0, v_i1 // T.int64(32)]) T.writes(dequantize[v_i0, v_i1]) dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale1[v_i0, v_i1 // T.int64(32)] @T.prim_func(private=True) def dequantize4(gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight1: T.Buffer((T.int64(2048), T.int64(1024)), "uint32"), gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale1: T.Buffer((T.int64(2048), T.int64(256)), "float16"), dequantize: T.Buffer((T.int64(2048), T.int64(8192)), "float16")): T.func_attr({"op_pattern": 2, "target": T.target({"keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) # with T.block("root"): compute = T.alloc_buffer((T.int64(2048), T.int64(8192)), "float16") for i0, i1 in T.grid(T.int64(2048), T.int64(8192)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight1[v_i0, v_i1 // T.int64(8)]) T.writes(compute[v_i0, v_i1]) 
compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight1[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15))) for i0, i1 in T.grid(T.int64(2048), T.int64(8192)): with T.block("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(compute[v_i0, v_i1], gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale1[v_i0, v_i1 // T.int64(32)]) T.writes(dequantize[v_i0, v_i1]) dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale1[v_i0, v_i1 // T.int64(32)] @T.prim_func(private=True) def fused_NT_matmul10_add10(layer_norm33: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), lv83: T.Buffer((T.int64(6144), T.int64(2048)), "float16"), gpt_neox_layers_0_attention_query_key_value_bias2: T.Buffer((T.int64(6144),), "float16"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(6144)), "float16")): T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(6144)), "float16") for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(6144), T.int64(2048)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(layer_norm33[v_i0, v_i1, v_k], lv83[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + layer_norm33[v_i0, v_i1, v_k] * lv83[v_i2, v_k] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(6144)): with T.block("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], gpt_neox_layers_0_attention_query_key_value_bias2[v_ax2]) T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2]) T_add_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, 
v_ax2] + gpt_neox_layers_0_attention_query_key_value_bias2[v_ax2] @T.prim_func(private=True) def fused_NT_matmul11_add11_add14_add14(reshape67: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), lv85: T.Buffer((T.int64(2048), T.int64(2048)), "float16"), gpt_neox_layers_0_attention_dense_bias2: T.Buffer((T.int64(2048),), "float16"), astype34: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), input_embed: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), T_add_intermediate_1_2: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")): T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16") T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16") T_add_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16") for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2048), T.int64(2048)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(reshape67[v_i0, v_i1, v_k], lv85[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + reshape67[v_i0, v_i1, v_k] * lv85[v_i2, v_k] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)): with T.block("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], gpt_neox_layers_0_attention_dense_bias2[v_ax2]) T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2]) T_add_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + gpt_neox_layers_0_attention_dense_bias2[v_ax2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)): with T.block("T_add_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) 
T.reads(astype34[v_ax0, v_ax1, v_ax2], T_add_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = astype34[v_ax0, v_ax1, v_ax2] + T_add_intermediate[v_ax0, v_ax1, v_ax2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)): with T.block("T_add_2"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_add_intermediate_1[v_ax0, v_ax1, v_ax2], input_embed[v_ax0, v_ax1, v_ax2]) T.writes(T_add_intermediate_1_2[v_ax0, v_ax1, v_ax2]) T_add_intermediate_1_2[v_ax0, v_ax1, v_ax2] = T_add_intermediate_1[v_ax0, v_ax1, v_ax2] + input_embed[v_ax0, v_ax1, v_ax2] @T.prim_func(private=True) def fused_NT_matmul12_add12_gelu2_cast6(layer_norm34: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), lv86: T.Buffer((T.int64(8192), T.int64(2048)), "float16"), gpt_neox_layers_0_mlp_dense_h_to_4h_bias2: T.Buffer((T.int64(8192),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(8192)), "float16")): T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(8192))) T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(8192))) T_multiply = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(8192))) compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(8192))) T_multiply_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(8192))) T_add = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(8192))) T_multiply_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(8192))) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(8192), T.int64(2048)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(layer_norm34[v_i0, v_i1, v_k], lv86[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0.0) NT_matmul_intermediate[v_i0, v_i1, 
v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", layer_norm34[v_i0, v_i1, v_k]) * T.Cast("float32", lv86[v_i2, v_k]) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(8192)): with T.block("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], gpt_neox_layers_0_mlp_dense_h_to_4h_bias2[v_ax2]) T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2]) T_add_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + gpt_neox_layers_0_mlp_dense_h_to_4h_bias2[v_ax2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(8192)): with T.block("T_multiply"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_add_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) T_multiply[v_ax0, v_ax1, v_ax2] = T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(8192)): with T.block("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_multiply[v_i0, v_i1, v_i2]) T.writes(compute[v_i0, v_i1, v_i2]) compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(8192)): with T.block("T_multiply_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(compute[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(8192)): with T.block("T_add_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) T.writes(T_add[v_ax0, v_ax1, v_ax2]) T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(8192)): with T.block("T_multiply_2"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, 
# NOTE(review): this text is a printed TVMScript IRModule (T.prim_func members of
# `class Module`) whose newlines and indentation were collapsed. Physical lines split
# statements at arbitrary points, and embedded `# with T.block("root"):` fragments turn
# the rest of each such line into a Python comment, so it is not parseable as stored.
# It looks machine-generated -- TODO confirm, and if so regenerate it rather than
# hand-editing. Comment lines inserted below only describe the line that follows.
#
# Next line: finishes a prim_func that starts before this chunk (elementwise
# T_add_intermediate * T_add, then block "compute_1" casts the (1, 1, 8192) result to
# float16). Then begins fused_NT_matmul13_add13_cast7: astype33 (1, 1, 8192) fp16
# times lv87 (2048, 8192) fp16 indexed [out, k] (i.e. x @ lv87^T), accumulated in a
# float32 buffer.
ax2]) T.reads(T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(8192)): with T.block("compute_1"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_multiply_intermediate[v_i0, v_i1, v_i2]) T.writes(compute_intermediate[v_i0, v_i1, v_i2]) compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", T_multiply_intermediate[v_i0, v_i1, v_i2]) @T.prim_func(private=True) def fused_NT_matmul13_add13_cast7(astype33: T.Buffer((T.int64(1), T.int64(1), T.int64(8192)), "float16"), lv87: T.Buffer((T.int64(2048), T.int64(8192)), "float16"), gpt_neox_layers_0_mlp_dense_4h_to_h_bias2: T.Buffer((T.int64(2048),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")): T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048))) T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048))) for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2048), T.int64(8192)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(astype33[v_i0, v_i1, v_k], lv87[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0.0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", astype33[v_i0, v_i1, v_k]) * T.Cast("float32", lv87[v_i2, v_k]) for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)): with T.block("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], gpt_neox_layers_0_mlp_dense_4h_to_h_bias2[v_ax2]) T.writes(T_add_intermediate[v_ax0, v_ax1, 
# Next line: fused_NT_matmul13_add13_cast7 adds its float32 bias and casts the
# (1, 1, 2048) result to float16. fused_NT_matmul14_cast8: layer_norm65 (1, 1, 2048)
# fp16 @ lv163 (vocab_size, 2048)^T, accumulated in a float16 buffer, then cast to a
# float32 output. NOTE(review): this 2048-term fp16 accumulation may lose precision
# relative to the float32-accumulating matmuls -- confirm it is intended.
# fused_NT_matmul1_add1_add4_add4 then starts.
v_ax2]) T_add_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + gpt_neox_layers_0_mlp_dense_4h_to_h_bias2[v_ax2] for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)): with T.block("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_add_intermediate[v_i0, v_i1, v_i2]) T.writes(compute_intermediate[v_i0, v_i1, v_i2]) compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", T_add_intermediate[v_i0, v_i1, v_i2]) @T.prim_func(private=True) def fused_NT_matmul14_cast8(layer_norm65: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), p_lv163: T.handle, p_output0: T.handle): T.func_attr({"tir.noalias": T.bool(True)}) vocab_size = T.int64() lv163 = T.match_buffer(p_lv163, (vocab_size, T.int64(2048)), "float16") compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), vocab_size)) # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), vocab_size), "float16") for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), vocab_size, T.int64(2048)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(layer_norm65[v_i0, v_i1, v_k], lv163[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + layer_norm65[v_i0, v_i1, v_k] * lv163[v_i2, v_k] for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), vocab_size): with T.block("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(NT_matmul_intermediate[v_i0, v_i1, v_i2]) T.writes(compute_intermediate[v_i0, v_i1, v_i2]) compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", NT_matmul_intermediate[v_i0, v_i1, v_i2]) @T.prim_func(private=True) def fused_NT_matmul1_add1_add4_add4(p_reshape195: T.handle, lv247: T.Buffer((T.int64(2048), T.int64(2048)), "float16"), 
# Next line: fused_NT_matmul1_add1_add4_add4 on (batch_size, 1, 2048) fp16:
# reshape195 @ lv247^T with float16 accumulation, + fp16 bias ("T_add"),
# + astype100 ("T_add_1").
gpt_neox_layers_0_attention_dense_bias4: T.Buffer((T.int64(2048),), "float16"), p_astype100: T.handle, p_input_embeds: T.handle, p_output0: T.handle): T.func_attr({"tir.noalias": T.bool(True)}) batch_size = T.int64() reshape195 = T.match_buffer(p_reshape195, (batch_size, T.int64(1), T.int64(2048)), "float16") astype100 = T.match_buffer(p_astype100, (batch_size, T.int64(1), T.int64(2048)), "float16") input_embeds = T.match_buffer(p_input_embeds, (batch_size, T.int64(1), T.int64(2048)), "float16") T_add_intermediate_1_2 = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(2048)), "float16") # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(2048)), "float16") T_add_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(2048)), "float16") T_add_intermediate_1 = T.alloc_buffer((batch_size, T.int64(1), T.int64(2048)), "float16") for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(2048), T.int64(2048)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(reshape195[v_i0, v_i1, v_k], lv247[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + reshape195[v_i0, v_i1, v_k] * lv247[v_i2, v_k] for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)): with T.block("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], gpt_neox_layers_0_attention_dense_bias4[v_ax2]) T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2]) T_add_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + gpt_neox_layers_0_attention_dense_bias4[v_ax2] for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)): with T.block("T_add_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(astype100[v_ax0, 
# Next line: fused_NT_matmul1_add1_add4_add4 adds input_embeds ("T_add_2") into the
# output T_add_intermediate_1_2. fused_NT_matmul2_add2_gelu_cast begins:
# layer_norm100 (batch_size, 1, 2048) @ lv248 (8192, 2048)^T, accumulated in float32.
v_ax1, v_ax2], T_add_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = astype100[v_ax0, v_ax1, v_ax2] + T_add_intermediate[v_ax0, v_ax1, v_ax2] for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)): with T.block("T_add_2"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_add_intermediate_1[v_ax0, v_ax1, v_ax2], input_embeds[v_ax0, v_ax1, v_ax2]) T.writes(T_add_intermediate_1_2[v_ax0, v_ax1, v_ax2]) T_add_intermediate_1_2[v_ax0, v_ax1, v_ax2] = T_add_intermediate_1[v_ax0, v_ax1, v_ax2] + input_embeds[v_ax0, v_ax1, v_ax2] @T.prim_func(private=True) def fused_NT_matmul2_add2_gelu_cast(p_layer_norm100: T.handle, lv248: T.Buffer((T.int64(8192), T.int64(2048)), "float16"), gpt_neox_layers_0_mlp_dense_h_to_4h_bias4: T.Buffer((T.int64(8192),), "float32"), p_output0: T.handle): T.func_attr({"tir.noalias": T.bool(True)}) batch_size = T.int64() layer_norm100 = T.match_buffer(p_layer_norm100, (batch_size, T.int64(1), T.int64(2048)), "float16") compute_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(8192)), "float16") # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(8192))) T_add_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(8192))) T_multiply = T.alloc_buffer((batch_size, T.int64(1), T.int64(8192))) compute = T.alloc_buffer((batch_size, T.int64(1), T.int64(8192))) T_multiply_1 = T.alloc_buffer((batch_size, T.int64(1), T.int64(8192))) T_add = T.alloc_buffer((batch_size, T.int64(1), T.int64(8192))) T_multiply_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(8192))) for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(8192), T.int64(2048)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(layer_norm100[v_i0, v_i1, v_k], lv248[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): 
# Next line: fused_NT_matmul2_add2_gelu_cast adds the float32 bias, then computes
# GELU in erf form: T_multiply = x * 0.70710678... (1/sqrt(2)), compute = erf(...),
# T_multiply_1 = 0.5 * erf, T_add = 0.5 + T_multiply_1.
NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0.0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", layer_norm100[v_i0, v_i1, v_k]) * T.Cast("float32", lv248[v_i2, v_k]) for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(8192)): with T.block("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], gpt_neox_layers_0_mlp_dense_h_to_4h_bias4[v_ax2]) T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2]) T_add_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + gpt_neox_layers_0_mlp_dense_h_to_4h_bias4[v_ax2] for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(8192)): with T.block("T_multiply"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_add_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) T_multiply[v_ax0, v_ax1, v_ax2] = T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) for i0, i1, i2 in T.grid(batch_size, T.int64(1), T.int64(8192)): with T.block("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_multiply[v_i0, v_i1, v_i2]) T.writes(compute[v_i0, v_i1, v_i2]) compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(8192)): with T.block("T_multiply_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(compute[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(8192)): with T.block("T_add_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) T.writes(T_add[v_ax0, v_ax1, v_ax2]) T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), 
# Next line: fused_NT_matmul2_add2_gelu_cast multiplies x by T_add and casts the
# (batch_size, 1, 8192) result to float16. fused_NT_matmul3_add3_cast1 begins:
# astype99 (batch_size, 1, 8192) @ lv249 (2048, 8192)^T, accumulated in float32.
T.int64(8192)): with T.block("T_multiply_2"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] for i0, i1, i2 in T.grid(batch_size, T.int64(1), T.int64(8192)): with T.block("compute_1"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_multiply_intermediate[v_i0, v_i1, v_i2]) T.writes(compute_intermediate[v_i0, v_i1, v_i2]) compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", T_multiply_intermediate[v_i0, v_i1, v_i2]) @T.prim_func(private=True) def fused_NT_matmul3_add3_cast1(p_astype99: T.handle, lv249: T.Buffer((T.int64(2048), T.int64(8192)), "float16"), gpt_neox_layers_0_mlp_dense_4h_to_h_bias4: T.Buffer((T.int64(2048),), "float32"), p_output0: T.handle): T.func_attr({"tir.noalias": T.bool(True)}) batch_size = T.int64() astype99 = T.match_buffer(p_astype99, (batch_size, T.int64(1), T.int64(8192)), "float16") compute_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(2048)), "float16") # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(2048))) T_add_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(2048))) for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(2048), T.int64(8192)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(astype99[v_i0, v_i1, v_k], lv249[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0.0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", astype99[v_i0, v_i1, v_k]) * T.Cast("float32", lv249[v_i2, v_k]) for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)): with T.block("T_add"): v_ax0, 
# Next line: fused_NT_matmul3_add3_cast1 adds the float32 bias and casts to float16.
# fused_NT_matmul4_cast2: layer_norm131 (batch_size, 1, 2048) @ lv325
# (vocab_size, 2048)^T, accumulated in a float16 buffer, with a float32 output
# (same fp16-accumulation caveat as fused_NT_matmul14_cast8).
v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], gpt_neox_layers_0_mlp_dense_4h_to_h_bias4[v_ax2]) T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2]) T_add_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + gpt_neox_layers_0_mlp_dense_4h_to_h_bias4[v_ax2] for i0, i1, i2 in T.grid(batch_size, T.int64(1), T.int64(2048)): with T.block("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_add_intermediate[v_i0, v_i1, v_i2]) T.writes(compute_intermediate[v_i0, v_i1, v_i2]) compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", T_add_intermediate[v_i0, v_i1, v_i2]) @T.prim_func(private=True) def fused_NT_matmul4_cast2(p_layer_norm131: T.handle, p_lv325: T.handle, p_output0: T.handle): T.func_attr({"tir.noalias": T.bool(True)}) batch_size = T.int64() layer_norm131 = T.match_buffer(p_layer_norm131, (batch_size, T.int64(1), T.int64(2048)), "float16") vocab_size = T.int64() lv325 = T.match_buffer(p_lv325, (vocab_size, T.int64(2048)), "float16") compute_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), vocab_size)) # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((batch_size, T.int64(1), vocab_size), "float16") for i0, i1, i2, k in T.grid(batch_size, T.int64(1), vocab_size, T.int64(2048)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(layer_norm131[v_i0, v_i1, v_k], lv325[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + layer_norm131[v_i0, v_i1, v_k] * lv325[v_i2, v_k] for i0, i1, i2 in T.grid(batch_size, T.int64(1), vocab_size): with T.block("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(NT_matmul_intermediate[v_i0, v_i1, v_i2]) T.writes(compute_intermediate[v_i0, v_i1, v_i2]) 
# Next line: fused_NT_matmul4_cast2 writes the float32 cast. fused_NT_matmul5_add5:
# layer_norm66 (1, seq_len, 2048) @ lv164 (6144, 2048)^T with float16 accumulation,
# + fp16 bias -> (1, seq_len, 6144). fused_NT_matmul6_add6_add9_add9 then starts.
compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", NT_matmul_intermediate[v_i0, v_i1, v_i2]) @T.prim_func(private=True) def fused_NT_matmul5_add5(p_layer_norm66: T.handle, lv164: T.Buffer((T.int64(6144), T.int64(2048)), "float16"), gpt_neox_layers_0_attention_query_key_value_bias3: T.Buffer((T.int64(6144),), "float16"), p_output0: T.handle): T.func_attr({"tir.noalias": T.bool(True)}) seq_len = T.int64() layer_norm66 = T.match_buffer(p_layer_norm66, (T.int64(1), seq_len, T.int64(2048)), "float16") T_add_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(6144)), "float16") # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(6144)), "float16") for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(6144), T.int64(2048)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(layer_norm66[v_i0, v_i1, v_k], lv164[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + layer_norm66[v_i0, v_i1, v_k] * lv164[v_i2, v_k] for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(6144)): with T.block("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], gpt_neox_layers_0_attention_query_key_value_bias3[v_ax2]) T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2]) T_add_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + gpt_neox_layers_0_attention_query_key_value_bias3[v_ax2] @T.prim_func(private=True) def fused_NT_matmul6_add6_add9_add9(p_reshape131: T.handle, lv166: T.Buffer((T.int64(2048), T.int64(2048)), "float16"), gpt_neox_layers_0_attention_dense_bias3: T.Buffer((T.int64(2048),), "float16"), p_astype67: T.handle, p_input_embeds: T.handle, p_output0: T.handle): T.func_attr({"tir.noalias": T.bool(True)}) 
# Next line: fused_NT_matmul6_add6_add9_add9 on (1, seq_len, 2048) fp16:
# reshape131 @ lv166^T (float16 accumulation) + fp16 bias, + astype67. Same
# structure as fused_NT_matmul1_add1_add4_add4, with a seq_len axis in place of
# batch_size.
seq_len = T.int64() reshape131 = T.match_buffer(p_reshape131, (T.int64(1), seq_len, T.int64(2048)), "float16") astype67 = T.match_buffer(p_astype67, (T.int64(1), seq_len, T.int64(2048)), "float16") input_embeds = T.match_buffer(p_input_embeds, (T.int64(1), seq_len, T.int64(2048)), "float16") T_add_intermediate_1_2 = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(2048)), "float16") # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(2048)), "float16") T_add_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(2048)), "float16") T_add_intermediate_1 = T.alloc_buffer((T.int64(1), seq_len, T.int64(2048)), "float16") for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(2048), T.int64(2048)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(reshape131[v_i0, v_i1, v_k], lv166[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + reshape131[v_i0, v_i1, v_k] * lv166[v_i2, v_k] for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)): with T.block("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], gpt_neox_layers_0_attention_dense_bias3[v_ax2]) T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2]) T_add_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + gpt_neox_layers_0_attention_dense_bias3[v_ax2] for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)): with T.block("T_add_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(astype67[v_ax0, v_ax1, v_ax2], T_add_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = astype67[v_ax0, v_ax1, v_ax2] + T_add_intermediate[v_ax0, v_ax1, v_ax2] for ax0, 
# Next line: fused_NT_matmul6_add6_add9_add9 adds input_embeds into its output.
# fused_NT_matmul7_add7_gelu1_cast3 begins: the seq_len counterpart of
# fused_NT_matmul2_add2_gelu_cast (layer_norm67 @ lv167 (8192, 2048)^T in float32).
ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)): with T.block("T_add_2"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_add_intermediate_1[v_ax0, v_ax1, v_ax2], input_embeds[v_ax0, v_ax1, v_ax2]) T.writes(T_add_intermediate_1_2[v_ax0, v_ax1, v_ax2]) T_add_intermediate_1_2[v_ax0, v_ax1, v_ax2] = T_add_intermediate_1[v_ax0, v_ax1, v_ax2] + input_embeds[v_ax0, v_ax1, v_ax2] @T.prim_func(private=True) def fused_NT_matmul7_add7_gelu1_cast3(p_layer_norm67: T.handle, lv167: T.Buffer((T.int64(8192), T.int64(2048)), "float16"), gpt_neox_layers_0_mlp_dense_h_to_4h_bias3: T.Buffer((T.int64(8192),), "float32"), p_output0: T.handle): T.func_attr({"tir.noalias": T.bool(True)}) seq_len = T.int64() layer_norm67 = T.match_buffer(p_layer_norm67, (T.int64(1), seq_len, T.int64(2048)), "float16") compute_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(8192)), "float16") # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(8192))) T_add_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(8192))) T_multiply = T.alloc_buffer((T.int64(1), seq_len, T.int64(8192))) compute = T.alloc_buffer((T.int64(1), seq_len, T.int64(8192))) T_multiply_1 = T.alloc_buffer((T.int64(1), seq_len, T.int64(8192))) T_add = T.alloc_buffer((T.int64(1), seq_len, T.int64(8192))) T_multiply_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(8192))) for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(8192), T.int64(2048)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(layer_norm67[v_i0, v_i1, v_k], lv167[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0.0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", layer_norm67[v_i0, v_i1, v_k]) * T.Cast("float32", lv167[v_i2, v_k]) for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, 
# Next line: fused_NT_matmul7_add7_gelu1_cast3 adds the float32 bias, then runs the
# erf-form GELU steps (x/sqrt(2), erf, *0.5, +0.5, *x).
T.int64(8192)): with T.block("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], gpt_neox_layers_0_mlp_dense_h_to_4h_bias3[v_ax2]) T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2]) T_add_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + gpt_neox_layers_0_mlp_dense_h_to_4h_bias3[v_ax2] for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(8192)): with T.block("T_multiply"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_add_intermediate[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) T_multiply[v_ax0, v_ax1, v_ax2] = T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) for i0, i1, i2 in T.grid(T.int64(1), seq_len, T.int64(8192)): with T.block("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_multiply[v_i0, v_i1, v_i2]) T.writes(compute[v_i0, v_i1, v_i2]) compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(8192)): with T.block("T_multiply_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(compute[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(8192)): with T.block("T_add_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) T.writes(T_add[v_ax0, v_ax1, v_ax2]) T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(8192)): with T.block("T_multiply_2"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) T.writes(T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = 
# Next line: fused_NT_matmul7_add7_gelu1_cast3 casts its result to float16.
# fused_NT_matmul8_add8_cast4 begins: the seq_len counterpart of
# fused_NT_matmul3_add3_cast1 (astype66 @ lv168 (2048, 8192)^T in float32).
T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] for i0, i1, i2 in T.grid(T.int64(1), seq_len, T.int64(8192)): with T.block("compute_1"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_multiply_intermediate[v_i0, v_i1, v_i2]) T.writes(compute_intermediate[v_i0, v_i1, v_i2]) compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", T_multiply_intermediate[v_i0, v_i1, v_i2]) @T.prim_func(private=True) def fused_NT_matmul8_add8_cast4(p_astype66: T.handle, lv168: T.Buffer((T.int64(2048), T.int64(8192)), "float16"), gpt_neox_layers_0_mlp_dense_4h_to_h_bias3: T.Buffer((T.int64(2048),), "float32"), p_output0: T.handle): T.func_attr({"tir.noalias": T.bool(True)}) seq_len = T.int64() astype66 = T.match_buffer(p_astype66, (T.int64(1), seq_len, T.int64(8192)), "float16") compute_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(2048)), "float16") # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(2048))) T_add_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(2048))) for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(2048), T.int64(8192)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(astype66[v_i0, v_i1, v_k], lv168[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0.0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", astype66[v_i0, v_i1, v_k]) * T.Cast("float32", lv168[v_i2, v_k]) for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)): with T.block("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], gpt_neox_layers_0_mlp_dense_4h_to_h_bias3[v_ax2]) T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2]) T_add_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + 
# Next line: fused_NT_matmul8_add8_cast4 finishes (bias add, cast to float16).
# fused_NT_matmul9_cast5: take1 (1, batch_size, 2048) @ lv244 (vocab_size, 2048)^T
# with float16 accumulation and a float32 output. fused_NT_matmul_add then starts.
gpt_neox_layers_0_mlp_dense_4h_to_h_bias3[v_ax2] for i0, i1, i2 in T.grid(T.int64(1), seq_len, T.int64(2048)): with T.block("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(T_add_intermediate[v_i0, v_i1, v_i2]) T.writes(compute_intermediate[v_i0, v_i1, v_i2]) compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", T_add_intermediate[v_i0, v_i1, v_i2]) @T.prim_func(private=True) def fused_NT_matmul9_cast5(p_take1: T.handle, p_lv244: T.handle, p_output0: T.handle): T.func_attr({"tir.noalias": T.bool(True)}) batch_size = T.int64() take1 = T.match_buffer(p_take1, (T.int64(1), batch_size, T.int64(2048)), "float16") vocab_size = T.int64() lv244 = T.match_buffer(p_lv244, (vocab_size, T.int64(2048)), "float16") compute_intermediate = T.match_buffer(p_output0, (T.int64(1), batch_size, vocab_size)) # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((T.int64(1), batch_size, vocab_size), "float16") for i0, i1, i2, k in T.grid(T.int64(1), batch_size, vocab_size, T.int64(2048)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(take1[v_i0, v_i1, v_k], lv244[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + take1[v_i0, v_i1, v_k] * lv244[v_i2, v_k] for i0, i1, i2 in T.grid(T.int64(1), batch_size, vocab_size): with T.block("compute"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(NT_matmul_intermediate[v_i0, v_i1, v_i2]) T.writes(compute_intermediate[v_i0, v_i1, v_i2]) compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", NT_matmul_intermediate[v_i0, v_i1, v_i2]) @T.prim_func(private=True) def fused_NT_matmul_add(p_layer_norm99: T.handle, lv245: T.Buffer((T.int64(6144), T.int64(2048)), "float16"), gpt_neox_layers_0_attention_query_key_value_bias4: T.Buffer((T.int64(6144),), "float16"), p_output0: T.handle): 
# Next line: fused_NT_matmul_add: layer_norm99 (batch_size, 1, 2048) @ lv245
# (6144, 2048)^T with float16 accumulation + fp16 bias. fused_reshape10_reshape11
# begins: lv84 (1, 8, 256) -> (1, 1, 8, 256) -> (1, 1, 2048).
T.func_attr({"tir.noalias": T.bool(True)}) batch_size = T.int64() layer_norm99 = T.match_buffer(p_layer_norm99, (batch_size, T.int64(1), T.int64(2048)), "float16") T_add_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(6144)), "float16") # with T.block("root"): NT_matmul_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(6144)), "float16") for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(6144), T.int64(2048)): with T.block("NT_matmul"): v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) T.reads(layer_norm99[v_i0, v_i1, v_k], lv245[v_i2, v_k]) T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2]) with T.init(): NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0) NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + layer_norm99[v_i0, v_i1, v_k] * lv245[v_i2, v_k] for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(6144)): with T.block("T_add"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], gpt_neox_layers_0_attention_query_key_value_bias4[v_ax2]) T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2]) T_add_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + gpt_neox_layers_0_attention_query_key_value_bias4[v_ax2] @T.prim_func(private=True) def fused_reshape10_reshape11(lv84: T.Buffer((T.int64(1), T.int64(8), T.int64(256)), "float16"), T_reshape_intermediate_1: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")): T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(8), T.int64(256)), "float16") for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(8), T.int64(256)): with T.block("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(lv84[T.int64(0), (v_ax3 // T.int64(256) + v_ax2) % T.int64(8), v_ax3 % T.int64(256)]) 
# Next line: fused_reshape10_reshape11 finishes (flattens the 8 x 256 axes to 2048).
# fused_reshape8_reshape9: add96 (1, 1, 6144) -> (1, 1, 24, 256) -> (1, 24, 256).
T.writes(T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv84[T.int64(0), (v_ax3 // T.int64(256) + v_ax2) % T.int64(8), v_ax3 % T.int64(256)] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)): with T.block("T_reshape_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_reshape_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(2048) // T.int64(256), v_ax2 % T.int64(256)]) T.writes(T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2]) T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2] = T_reshape_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(2048) // T.int64(256), v_ax2 % T.int64(256)] @T.prim_func(private=True) def fused_reshape8_reshape9(add96: T.Buffer((T.int64(1), T.int64(1), T.int64(6144)), "float16"), T_reshape_intermediate_1: T.Buffer((T.int64(1), T.int64(24), T.int64(256)), "float16")): T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(24), T.int64(256)), "float16") for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(24), T.int64(256)): with T.block("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(add96[T.int64(0), T.int64(0), (v_ax2 * T.int64(256) + v_ax3) % T.int64(6144)]) T.writes(T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = add96[T.int64(0), T.int64(0), (v_ax2 * T.int64(256) + v_ax3) % T.int64(6144)] for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(24), T.int64(256)): with T.block("T_reshape_1"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(T_reshape_intermediate[T.int64(0), T.int64(0), (v_ax2 // T.int64(256) + v_ax1) % T.int64(24), v_ax2 % T.int64(256)]) T.writes(T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2]) T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2] = T_reshape_intermediate[T.int64(0), T.int64(0), (v_ax2 // T.int64(256) + v_ax1) % 
# Next line: fused_reshape8_reshape9 finishes. fused_rope reads qkv (seq_len, 24, 256)
# fp16 and splits heads 0-7 into q, 8-15 into k, and 16-23 into v (each
# (seq_len, 8, 256)). For q and k, when apply_rope > 0 and d < 64 it writes
# cos(freq) * x[d] + sin(freq) * r, where r = -x[d + 32] for d < 32 and x[d - 32]
# otherwise, and freq = position_map[s] / 10000 ** (((2 * d) % 64) / 64), computed in
# float32 and cast back to float16. All other dims, and all of v, are copied unchanged.
T.int64(24), v_ax2 % T.int64(256)] @T.prim_func def fused_rope(var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, apply_rope: T.int32): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) seq_len = T.int32() qkv = T.match_buffer(var_qkv, (seq_len, 24, 256), "float16") position_map = T.match_buffer(var_position_map, (seq_len,), "int32", offset_factor=1) q = T.match_buffer(var_q, (seq_len, 8, 256), "float16") k = T.match_buffer(var_k, (seq_len, 8, 256), "float16") v = T.match_buffer(var_v, (seq_len, 8, 256), "float16") # with T.block("root"): for iters_0, iters_1, iters_2 in T.grid(seq_len, 24, 256): with T.block("llama_fused_rope"): s, h, d = T.axis.remap("SSS", [iters_0, iters_1, iters_2]) T.reads(position_map[s], qkv[s, h, d - 32:d - 32 + 65]) T.writes(q[s, h, d], k[s, h - 8, d], v[s, h - 16, d]) if h < 8: freq = T.float32() q[s, h, d] = T.if_then_else(apply_rope > 0 and d < 64, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[s, h, d]) + T.sin(freq) * T.Cast("float32", T.if_then_else(d < 32, qkv[s, h, d + 32] * T.float16(-1.0), qkv[s, h, d - 32]))), where={freq: T.Cast("float32", position_map[s]) / T.pow(T.float32(10000.0), T.Cast("float32", d * 2 % 64) / T.float32(64.0))}), qkv[s, h, d]) else: if h < 16: freq = T.float32() k[s, h - 8, d] = T.if_then_else(apply_rope > 0 and d < 64, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[s, h, d]) + T.sin(freq) * T.Cast("float32", T.if_then_else(d < 32, qkv[s, h, d + 32] * T.float16(-1.0), qkv[s, h, d - 32]))), where={freq: T.Cast("float32", position_map[s]) / T.pow(T.float32(10000.0), 
# Next line: fused_rope finishes (v copy). gather_probs: dst[b, j] = src[indices[b], j]
# (row gather). index: copies layer_norm32[:, seq_len - 1, :] (the last position of
# a (1, seq_len, 2048) input) into a (1, 1, 2048) buffer. layer_norm then begins.
T.Cast("float32", d * 2 % 64) / T.float32(64.0))}), qkv[s, h, d]) else: v[s, h - 16, d] = qkv[s, h, d] @T.prim_func def gather_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) m, n = T.int32(is_size_var=True), T.int32(is_size_var=True) src = T.match_buffer(var_src, (m, n)) batch_size = T.int32(is_size_var=True) indices = T.match_buffer(var_indices, (batch_size,), "int32") dst = T.match_buffer(var_dst, (batch_size, n)) # with T.block("root"): for b, j in T.grid(batch_size, n): with T.block("gather_2d"): vb, vj = T.axis.remap("SS", [b, j]) T.reads(src[indices[vb], vj], indices[vb]) T.writes(dst[vb, vj]) dst[vb, vj] = src[indices[vb], vj] @T.prim_func(private=True) def index(var_layer_norm32: T.handle, index: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")): T.func_attr({"op_pattern": 8, "target": T.target({"keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) seq_len = T.int64() layer_norm32 = T.match_buffer(var_layer_norm32, (T.int64(1), seq_len, T.int64(2048)), "float16") # with T.block("root"): for i, _, k in T.grid(T.int64(1), T.int64(1), T.int64(2048)): with T.block("index"): v_i, v__, v_k = T.axis.remap("SSS", [i, _, k]) T.reads(layer_norm32[v_i, seq_len - T.int64(1), v_k]) T.writes(index[v_i, v__, v_k]) index[v_i, v__, v_k] = layer_norm32[v_i, seq_len - T.int64(1), v_k] @T.prim_func(private=True) def 
# Next line: layer_norm on (batch_size, 1, 2048) fp16. A reduction block accumulates
# the per-row float32 sum (…_v0) and sum of squares (…_v1) over the 2048-wide axis.
layer_norm(var_input_embeds: T.handle, gpt_neox_layers_0_input_layernorm_weight4: T.Buffer((T.int64(2048),), "float16"), gpt_neox_layers_0_input_layernorm_bias4: T.Buffer((T.int64(2048),), "float16"), var_T_layer_norm: T.handle): T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)}) batch_size = T.int64() input_embeds = T.match_buffer(var_input_embeds, (batch_size, T.int64(1), T.int64(2048)), "float16") T_layer_norm = T.match_buffer(var_T_layer_norm, (batch_size, T.int64(1), T.int64(2048)), "float16") # with T.block("root"): input_embeds_red_temp_v0 = T.alloc_buffer((batch_size, T.int64(1))) input_embeds_red_temp_v1 = T.alloc_buffer((batch_size, T.int64(1))) for ax0, ax1, k2 in T.grid(batch_size, T.int64(1), T.int64(2048)): with T.block("input_embeds_red_temp"): v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) T.reads(input_embeds[v_ax0, v_ax1, v_k2]) T.writes(input_embeds_red_temp_v0[v_ax0, v_ax1], input_embeds_red_temp_v1[v_ax0, v_ax1]) with T.init(): input_embeds_red_temp_v0[v_ax0, v_ax1] = T.float32(0.0) input_embeds_red_temp_v1[v_ax0, v_ax1] = T.float32(0.0) v_input_embeds_red_temp_v0: T.float32 = input_embeds_red_temp_v0[v_ax0, v_ax1] + T.Cast("float32", input_embeds[v_ax0, v_ax1, v_k2]) v_input_embeds_red_temp_v1: T.float32 = input_embeds_red_temp_v1[v_ax0, v_ax1] + T.Cast("float32", input_embeds[v_ax0, v_ax1, v_k2]) * T.Cast("float32", input_embeds[v_ax0, v_ax1, v_k2]) input_embeds_red_temp_v0[v_ax0, v_ax1] = v_input_embeds_red_temp_v0 input_embeds_red_temp_v1[v_ax0, v_ax1] = v_input_embeds_red_temp_v1 for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)): with T.block("T_layer_norm"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(input_embeds[v_ax0, v_ax1, v_ax2], input_embeds_red_temp_v0[v_ax0, v_ax1], input_embeds_red_temp_v1[v_ax0, v_ax1], gpt_neox_layers_0_input_layernorm_weight4[v_ax2], gpt_neox_layers_0_input_layernorm_bias4[v_ax2]) T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) T_layer_norm[v_ax0, v_ax1, 
# Next line: layer_norm normalizes with mean = sum * 0.00048828125 (1/2048) and
# var = sumsq/2048 - mean^2, eps = 1e-5, via rsqrt. It casts to float16, then
# multiplies by the fp16 weight and adds the fp16 bias. NOTE(review): the
# E[x^2] - E[x]^2 form can cancel badly in float32 when |mean| >> std --
# presumably acceptable for these activations; confirm. layer_norm1 (the
# (1, seq_len, 2048) variant) then begins.
v_ax2] = T.Cast("float16", (T.Cast("float32", input_embeds[v_ax0, v_ax1, v_ax2]) - input_embeds_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00048828125)) * T.rsqrt(input_embeds_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00048828125) - input_embeds_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00048828125) * (input_embeds_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00048828125)) + T.float32(1.0000000000000001e-05))) * gpt_neox_layers_0_input_layernorm_weight4[v_ax2] + gpt_neox_layers_0_input_layernorm_bias4[v_ax2] @T.prim_func(private=True) def layer_norm1(var_input_embeds: T.handle, gpt_neox_layers_0_input_layernorm_weight3: T.Buffer((T.int64(2048),), "float16"), gpt_neox_layers_0_input_layernorm_bias3: T.Buffer((T.int64(2048),), "float16"), var_T_layer_norm: T.handle): T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)}) seq_len = T.int64() input_embeds = T.match_buffer(var_input_embeds, (T.int64(1), seq_len, T.int64(2048)), "float16") T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(1), seq_len, T.int64(2048)), "float16") # with T.block("root"): input_embeds_red_temp_v0 = T.alloc_buffer((T.int64(1), seq_len)) input_embeds_red_temp_v1 = T.alloc_buffer((T.int64(1), seq_len)) for ax0, ax1, k2 in T.grid(T.int64(1), seq_len, T.int64(2048)): with T.block("input_embeds_red_temp"): v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) T.reads(input_embeds[v_ax0, v_ax1, v_k2]) T.writes(input_embeds_red_temp_v0[v_ax0, v_ax1], input_embeds_red_temp_v1[v_ax0, v_ax1]) with T.init(): input_embeds_red_temp_v0[v_ax0, v_ax1] = T.float32(0.0) input_embeds_red_temp_v1[v_ax0, v_ax1] = T.float32(0.0) v_input_embeds_red_temp_v0: T.float32 = input_embeds_red_temp_v0[v_ax0, v_ax1] + T.Cast("float32", input_embeds[v_ax0, v_ax1, v_k2]) v_input_embeds_red_temp_v1: T.float32 = input_embeds_red_temp_v1[v_ax0, v_ax1] + T.Cast("float32", input_embeds[v_ax0, v_ax1, v_k2]) * T.Cast("float32", input_embeds[v_ax0, v_ax1, v_k2]) input_embeds_red_temp_v0[v_ax0, v_ax1] = v_input_embeds_red_temp_v0 
# Next line: layer_norm1 finishes with the same normalization formula as layer_norm.
# layer_norm2 (the static (1, 1, 2048) variant) begins.
input_embeds_red_temp_v1[v_ax0, v_ax1] = v_input_embeds_red_temp_v1 for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)): with T.block("T_layer_norm"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(input_embeds[v_ax0, v_ax1, v_ax2], input_embeds_red_temp_v0[v_ax0, v_ax1], input_embeds_red_temp_v1[v_ax0, v_ax1], gpt_neox_layers_0_input_layernorm_weight3[v_ax2], gpt_neox_layers_0_input_layernorm_bias3[v_ax2]) T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) T_layer_norm[v_ax0, v_ax1, v_ax2] = T.Cast("float16", (T.Cast("float32", input_embeds[v_ax0, v_ax1, v_ax2]) - input_embeds_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00048828125)) * T.rsqrt(input_embeds_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00048828125) - input_embeds_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00048828125) * (input_embeds_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00048828125)) + T.float32(1.0000000000000001e-05))) * gpt_neox_layers_0_input_layernorm_weight3[v_ax2] + gpt_neox_layers_0_input_layernorm_bias3[v_ax2] @T.prim_func(private=True) def layer_norm2(input_embed: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), gpt_neox_layers_0_input_layernorm_weight2: T.Buffer((T.int64(2048),), "float16"), gpt_neox_layers_0_input_layernorm_bias2: T.Buffer((T.int64(2048),), "float16"), T_layer_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")): T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)}) # with T.block("root"): input_embed_red_temp_v0 = T.alloc_buffer((T.int64(1), T.int64(1))) input_embed_red_temp_v1 = T.alloc_buffer((T.int64(1), T.int64(1))) for ax0, ax1, k2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)): with T.block("input_embed_red_temp"): v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2]) T.reads(input_embed[v_ax0, v_ax1, v_k2]) T.writes(input_embed_red_temp_v0[v_ax0, v_ax1], input_embed_red_temp_v1[v_ax0, v_ax1]) with T.init(): input_embed_red_temp_v0[v_ax0, v_ax1] = T.float32(0.0) input_embed_red_temp_v1[v_ax0, v_ax1] = 
# Next line: layer_norm2 finishes with the same normalization formula.
# merge_state_inplace begins and continues past this chunk (not described here).
T.float32(0.0) v_input_embed_red_temp_v0: T.float32 = input_embed_red_temp_v0[v_ax0, v_ax1] + T.Cast("float32", input_embed[v_ax0, v_ax1, v_k2]) v_input_embed_red_temp_v1: T.float32 = input_embed_red_temp_v1[v_ax0, v_ax1] + T.Cast("float32", input_embed[v_ax0, v_ax1, v_k2]) * T.Cast("float32", input_embed[v_ax0, v_ax1, v_k2]) input_embed_red_temp_v0[v_ax0, v_ax1] = v_input_embed_red_temp_v0 input_embed_red_temp_v1[v_ax0, v_ax1] = v_input_embed_red_temp_v1 for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)): with T.block("T_layer_norm"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(input_embed[v_ax0, v_ax1, v_ax2], input_embed_red_temp_v0[v_ax0, v_ax1], input_embed_red_temp_v1[v_ax0, v_ax1], gpt_neox_layers_0_input_layernorm_weight2[v_ax2], gpt_neox_layers_0_input_layernorm_bias2[v_ax2]) T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2]) T_layer_norm[v_ax0, v_ax1, v_ax2] = T.Cast("float16", (T.Cast("float32", input_embed[v_ax0, v_ax1, v_ax2]) - input_embed_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00048828125)) * T.rsqrt(input_embed_red_temp_v1[v_ax0, v_ax1] * T.float32(0.00048828125) - input_embed_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00048828125) * (input_embed_red_temp_v0[v_ax0, v_ax1] * T.float32(0.00048828125)) + T.float32(1.0000000000000001e-05))) * gpt_neox_layers_0_input_layernorm_weight2[v_ax2] + gpt_neox_layers_0_input_layernorm_bias2[v_ax2] @T.prim_func def merge_state_inplace(v: T.handle, s: T.handle, v_other: T.handle, s_other: T.handle): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) N, H, D = T.int32(is_size_var=True), T.int32(is_size_var=True), 
T.int32(is_size_var=True) V = T.match_buffer(v, (N, H, D), "float16") S = T.match_buffer(s, (N, H)) V_other = T.match_buffer(v_other, (N, H, D), "float16") S_other = T.match_buffer(s_other, (N, H)) # with T.block("root"): for bx in T.thread_binding(N, thread="blockIdx.x"): for by in T.thread_binding(2, thread="blockIdx.y"): for ty in T.thread_binding(4, thread="threadIdx.y"): for tx in T.thread_binding(64, thread="threadIdx.x"): with T.block("merge"): T.reads(S[bx, ty + by * 4], S_other[bx, ty + by * 4], V[bx, ty + by * 4, tx * 4:tx * 4 + 4], V_other[bx, ty + by * 4, tx * 4:tx * 4 + 4]) T.writes(V[bx, ty + by * 4, tx * 4:tx * 4 + 4], S[bx, ty + by * 4]) s_val = T.alloc_buffer((1,), scope="local") s_other_val = T.alloc_buffer((1,), scope="local") s_max = T.alloc_buffer((1,), scope="local") scale = T.alloc_buffer((1,), scope="local") other_scale = T.alloc_buffer((1,), scope="local") v_vec = T.alloc_buffer((4,), "float16", scope="local") v_other_vec = T.alloc_buffer((4,), "float16", scope="local") s_val[0] = S[bx, ty + by * 4] s_other_val[0] = S_other[bx, ty + by * 4] s_max[0] = T.max(s_val[0], s_other_val[0]) s_val[0] = T.exp2(s_val[0] - s_max[0]) s_other_val[0] = T.exp2(s_other_val[0] - s_max[0]) scale[0] = s_val[0] / (s_val[0] + s_other_val[0]) other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0]) for vec in T.vectorized(4): v_vec[vec] = V[bx, ty + by * 4, tx * 4 + vec] for vec in T.vectorized(4): v_other_vec[vec] = V_other[bx, ty + by * 4, tx * 4 + vec] for vec in range(4): v_vec[vec] = T.Cast("float16", T.Cast("float32", v_vec[vec]) * scale[0] + T.Cast("float32", v_other_vec[vec]) * other_scale[0]) for vec in T.vectorized(4): V[bx, ty + by * 4, tx * 4 + vec] = v_vec[vec] S[bx, ty + by * 4] = T.log2(s_val[0] + s_other_val[0]) + s_max[0] @T.prim_func(private=True) def reshape(var_add288: T.handle, var_T_reshape: T.handle): T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)}) batch_size = T.int64() add288 = T.match_buffer(var_add288, (batch_size, 
T.int64(1), T.int64(6144)), "float16") T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(24), T.int64(256)), "float16") # with T.block("root"): for ax0, ax1, ax2, ax3 in T.grid(batch_size, T.int64(1), T.int64(24), T.int64(256)): with T.block("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(add288[((v_ax2 * T.int64(256) + v_ax3) // T.int64(6144) + v_ax0 + v_ax1) % batch_size, T.int64(0), (v_ax2 * T.int64(256) + v_ax3) % T.int64(6144)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = add288[((v_ax2 * T.int64(256) + v_ax3) // T.int64(6144) + v_ax0 + v_ax1) % batch_size, T.int64(0), (v_ax2 * T.int64(256) + v_ax3) % T.int64(6144)] @T.prim_func(private=True) def reshape1(var_reshape192: T.handle, var_T_reshape: T.handle): T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)}) batch_size = T.int64() reshape192 = T.match_buffer(var_reshape192, (batch_size, T.int64(1), T.int64(24), T.int64(256)), "float16") T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(24), T.int64(256)), "float16") # with T.block("root"): for ax0, ax1, ax2 in T.grid(batch_size, T.int64(24), T.int64(256)): with T.block("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(reshape192[((v_ax2 // T.int64(256) + v_ax1) // T.int64(24) + v_ax0) % batch_size, T.int64(0), (v_ax2 // T.int64(256) + v_ax1) % T.int64(24), v_ax2 % T.int64(256)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) T_reshape[v_ax0, v_ax1, v_ax2] = reshape192[((v_ax2 // T.int64(256) + v_ax1) // T.int64(24) + v_ax0) % batch_size, T.int64(0), (v_ax2 // T.int64(256) + v_ax1) % T.int64(24), v_ax2 % T.int64(256)] @T.prim_func(private=True) def reshape2(var_lv246: T.handle, var_T_reshape: T.handle): T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)}) batch_size = T.int64() lv246 = T.match_buffer(var_lv246, (batch_size, T.int64(8), T.int64(256)), "float16") T_reshape = 
T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(8), T.int64(256)), "float16") # with T.block("root"): for ax0, ax1, ax2, ax3 in T.grid(batch_size, T.int64(1), T.int64(8), T.int64(256)): with T.block("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(lv246[((v_ax3 // T.int64(256) + v_ax2) // T.int64(8) + v_ax0 + v_ax1) % batch_size, (v_ax3 // T.int64(256) + v_ax2) % T.int64(8), v_ax3 % T.int64(256)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = lv246[((v_ax3 // T.int64(256) + v_ax2) // T.int64(8) + v_ax0 + v_ax1) % batch_size, (v_ax3 // T.int64(256) + v_ax2) % T.int64(8), v_ax3 % T.int64(256)] @T.prim_func(private=True) def reshape3(var_reshape194: T.handle, var_T_reshape: T.handle): T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)}) batch_size = T.int64() reshape194 = T.match_buffer(var_reshape194, (batch_size, T.int64(1), T.int64(8), T.int64(256)), "float16") T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(2048)), "float16") # with T.block("root"): for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)): with T.block("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(reshape194[(v_ax2 // T.int64(2048) + v_ax0 + v_ax1) % batch_size, T.int64(0), v_ax2 % T.int64(2048) // T.int64(256), v_ax2 % T.int64(256)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) T_reshape[v_ax0, v_ax1, v_ax2] = reshape194[(v_ax2 // T.int64(2048) + v_ax0 + v_ax1) % batch_size, T.int64(0), v_ax2 % T.int64(2048) // T.int64(256), v_ax2 % T.int64(256)] @T.prim_func(private=True) def reshape4(var_add192: T.handle, var_T_reshape: T.handle): T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)}) seq_len = T.int64() add192 = T.match_buffer(var_add192, (T.int64(1), seq_len, T.int64(6144)), "float16") T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(24), T.int64(256)), "float16") # with T.block("root"): for 
ax0, ax1, ax2, ax3 in T.grid(T.int64(1), seq_len, T.int64(24), T.int64(256)): with T.block("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(add192[T.int64(0), ((v_ax2 * T.int64(256) + v_ax3) // T.int64(6144) + v_ax0 * seq_len + v_ax1) % seq_len, (v_ax2 * T.int64(256) + v_ax3) % T.int64(6144)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = add192[T.int64(0), ((v_ax2 * T.int64(256) + v_ax3) // T.int64(6144) + v_ax0 * seq_len + v_ax1) % seq_len, (v_ax2 * T.int64(256) + v_ax3) % T.int64(6144)] @T.prim_func(private=True) def reshape5(var_reshape128: T.handle, var_T_reshape: T.handle): T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)}) seq_len = T.int64() reshape128 = T.match_buffer(var_reshape128, (T.int64(1), seq_len, T.int64(24), T.int64(256)), "float16") T_reshape = T.match_buffer(var_T_reshape, (seq_len, T.int64(24), T.int64(256)), "float16") # with T.block("root"): for ax0, ax1, ax2 in T.grid(seq_len, T.int64(24), T.int64(256)): with T.block("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(reshape128[T.int64(0), ((v_ax2 // T.int64(256) + v_ax1) // T.int64(24) + v_ax0) % seq_len, (v_ax2 // T.int64(256) + v_ax1) % T.int64(24), v_ax2 % T.int64(256)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) T_reshape[v_ax0, v_ax1, v_ax2] = reshape128[T.int64(0), ((v_ax2 // T.int64(256) + v_ax1) // T.int64(24) + v_ax0) % seq_len, (v_ax2 // T.int64(256) + v_ax1) % T.int64(24), v_ax2 % T.int64(256)] @T.prim_func(private=True) def reshape6(var_lv165: T.handle, var_T_reshape: T.handle): T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)}) seq_len = T.int64() lv165 = T.match_buffer(var_lv165, (seq_len, T.int64(8), T.int64(256)), "float16") T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(8), T.int64(256)), "float16") # with T.block("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), seq_len, T.int64(8), T.int64(256)): with 
T.block("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(lv165[((v_ax3 // T.int64(256) + v_ax2) // T.int64(8) + v_ax0 * seq_len + v_ax1) % seq_len, (v_ax3 // T.int64(256) + v_ax2) % T.int64(8), v_ax3 % T.int64(256)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = lv165[((v_ax3 // T.int64(256) + v_ax2) // T.int64(8) + v_ax0 * seq_len + v_ax1) % seq_len, (v_ax3 // T.int64(256) + v_ax2) % T.int64(8), v_ax3 % T.int64(256)] @T.prim_func(private=True) def reshape7(var_reshape130: T.handle, var_T_reshape: T.handle): T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)}) seq_len = T.int64() reshape130 = T.match_buffer(var_reshape130, (T.int64(1), seq_len, T.int64(8), T.int64(256)), "float16") T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(2048)), "float16") # with T.block("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)): with T.block("T_reshape"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(reshape130[T.int64(0), (v_ax2 // T.int64(2048) + v_ax0 * seq_len + v_ax1) % seq_len, v_ax2 % T.int64(2048) // T.int64(256), v_ax2 % T.int64(256)]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) T_reshape[v_ax0, v_ax1, v_ax2] = reshape130[T.int64(0), (v_ax2 // T.int64(2048) + v_ax0 * seq_len + v_ax1) % seq_len, v_ax2 % T.int64(2048) // T.int64(256), v_ax2 % T.int64(256)] @T.prim_func def scatter_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) batch_size, n = T.int32(is_size_var=True), T.int32(is_size_var=True) 
src = T.match_buffer(var_src, (batch_size, n)) indices = T.match_buffer(var_indices, (batch_size,), "int32") m = T.int32(is_size_var=True) dst = T.match_buffer(var_dst, (m, n)) # with T.block("root"): for b, j in T.grid(batch_size, n): with T.block("scatter_2d"): vb, vj = T.axis.remap("SS", [b, j]) T.reads(src[vb, vj], indices[vb]) T.writes(dst[indices[vb], vj]) dst[indices[vb], vj] = src[vb, vj] @T.prim_func def softmax_with_chunked_sum(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle, var_softmax: T.handle): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True) A = T.match_buffer(var_A, (batch_size, vocab_size)) temperature = T.match_buffer(var_temperature, (batch_size,)) num_chunks = T.int64(is_size_var=True) chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks)) chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks)) softmax = T.match_buffer(var_softmax, (batch_size, vocab_size)) # with T.block("root"): temp_max_shared = T.alloc_buffer((batch_size,), scope="shared") temp_sum_shared = T.alloc_buffer((batch_size,), scope="shared") for l0_l1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"): for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): with T.block("max"): v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks) v1 = 
T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1) T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks) T.reads(chunked_max[v0, v1]) T.writes(temp_max_shared[v0]) with T.init(): temp_max_shared[v0] = T.float32(-340282346638528859811704183484516925440.0) temp_max_shared[v0] = T.max(temp_max_shared[v0], chunked_max[v0, v1]) for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): with T.block("sum_exp"): v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks) v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1) T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks) T.reads(temperature[v0], chunked_sum[v0, v1], chunked_max[v0, v1], temp_max_shared[v0]) T.writes(temp_sum_shared[v0]) with T.init(): temp_sum_shared[v0] = T.float32(0.0) temp_sum_shared[v0] = temp_sum_shared[v0] + T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(chunked_sum[v0, v1] + chunked_max[v0, v1] - temp_max_shared[v0]), T.Cast("float32", chunked_max[v0, v1] == temp_max_shared[v0]) * chunked_sum[v0, v1]) for l2_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): for l2_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"): for l2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"): with T.block("log_pad"): v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks) v1 = T.axis.spatial(num_chunks, l0_l1_fused % num_chunks) v2 = T.axis.spatial(T.int64(4096), l2_0 * T.int64(256) + l2_1 * T.int64(32) + l2_2) T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2], temp_sum_shared[v0], temp_max_shared[v0]) T.writes(softmax[v0, v1 * T.int64(4096) + v2]) if v1 * T.int64(4096) + v2 < vocab_size: softmax[v0, v1 * T.int64(4096) + v2] = T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A[v0, v1 * 
T.int64(4096) + v2] / temperature[v0] - (T.log(temp_sum_shared[v0]) + temp_max_shared[v0])), T.Cast("float32", A[v0, v1 * T.int64(4096) + v2] == temp_max_shared[v0]) / temp_sum_shared[v0]) @T.prim_func(private=True) def take(var_layer_norm98: T.handle, var_logit_positions: T.handle, var_T_take: T.handle): T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)}) seq_len = T.int64() layer_norm98 = T.match_buffer(var_layer_norm98, (T.int64(1), seq_len, T.int64(2048)), "float16") batch_size = T.int64() logit_positions = T.match_buffer(var_logit_positions, (batch_size,), "int32") T_take = T.match_buffer(var_T_take, (T.int64(1), batch_size, T.int64(2048)), "float16") # with T.block("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), batch_size, T.int64(2048)): with T.block("T_take"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) T.reads(layer_norm98[v_ax0, logit_positions[v_ax1], v_ax2], logit_positions[v_ax1]) T.writes(T_take[v_ax0, v_ax1, v_ax2]) T_take[v_ax0, v_ax1, v_ax2] = layer_norm98[v_ax0, logit_positions[v_ax1], v_ax2] @T.prim_func(private=True) def take1(var_lv: T.handle, var_input_ids: T.handle, var_T_take: T.handle): T.func_attr({"op_pattern": 8, "tir.noalias": T.bool(True)}) vocab_size = T.int64() lv = T.match_buffer(var_lv, (vocab_size, T.int64(2048)), "float16") seq_len = T.int64() input_ids = T.match_buffer(var_input_ids, (seq_len,), "int32") T_take = T.match_buffer(var_T_take, (seq_len, T.int64(2048)), "float16") # with T.block("root"): for ax0, ax1 in T.grid(seq_len, T.int64(2048)): with T.block("T_take"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(lv[input_ids[v_ax0], v_ax1], input_ids[v_ax0]) T.writes(T_take[v_ax0, v_ax1]) T_take[v_ax0, v_ax1] = lv[input_ids[v_ax0], v_ax1] @T.prim_func def tir_kv_cache_debug_get_kv(var_pages: T.handle, var_position_map: T.handle, var_k_data: T.handle, var_v_data: T.handle, layer_id: T.int64): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", 
"mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) num_pages, page_size = T.int64(), T.int64(is_size_var=True) pages = T.match_buffer(var_pages, (num_pages, 2, 8, page_size, 256), "float16", offset_factor=1) seqlen = T.int64(is_size_var=True) position_map = T.match_buffer(var_position_map, (seqlen,), "int32", offset_factor=1) k_data = T.match_buffer(var_k_data, (16, seqlen, 8, 256), "float16") v_data = T.match_buffer(var_v_data, (16, seqlen, 8, 256), "float16") # with T.block("root"): for p, h, d in T.grid(seqlen, 8, 256): with T.block("copy0"): vp, vh, vd = T.axis.remap("SSS", [p, h, d]) T.reads(position_map[vp], pages[T.Cast("int64", position_map[vp]) // page_size, 0:2, vh, T.Cast("int64", position_map[vp]) % page_size, vd]) T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) position: T.int32 = position_map[vp] k_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 0, vh, T.Cast("int64", position) % page_size, vd] v_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 1, vh, T.Cast("int64", position) % page_size, vd] @T.prim_func def tir_kv_cache_transpose_append(var_pages: T.handle, var_k_data: T.handle, var_v_data: T.handle, var_position_map: T.handle): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) num_pages = T.int64() pages = T.match_buffer(var_pages, (num_pages, 2, 8, 16, 256), 
"float16", offset_factor=1) ntoken = T.int64(is_size_var=True) k_data = T.match_buffer(var_k_data, (ntoken, 8, 256), "float16") v_data = T.match_buffer(var_v_data, (ntoken, 8, 256), "float16") position_map = T.match_buffer(var_position_map, (ntoken,), "int32", offset_factor=1) # with T.block("root"): for global_pos, h, f in T.grid(ntoken, 8, 256): if position_map[global_pos] != -1: with T.block("k_transpose_append"): vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf]) position: T.int32 = position_map[vgpos] pages[position // 16, 0, vh, position % 16, vf] = k_data[vgpos, vh, vf] with T.block("v_transpose_append"): vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) T.reads(position_map[vgpos], v_data[vgpos, vh, vf]) T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) position: T.int32 = position_map[vgpos] pages[position // 16, 1, vh, position % 16, vf] = v_data[vgpos, vh, vf] @T.prim_func def tree_attn_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, tree_order_indptr_handle: T.handle, tree_order_handle: T.handle): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) total_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, 
(total_len, 8, 256), "float16") batch_size = T.int32(is_size_var=True) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) max_num_pages = T.int32(is_size_var=True) pages = T.match_buffer(var_pages, (max_num_pages, 2, 8, 32, 256), "float16") page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1) nnz_pages = T.int32(is_size_var=True) page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1) length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1) k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1) q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1) output = T.match_buffer(var_output, (total_len, 8, 256), "float16") lse = T.match_buffer(var_lse, (total_len, 8)) tree_order_indptr = T.match_buffer(tree_order_indptr_handle, (batch_size + 1,), "int32", offset_factor=1) total_tree_order_len = T.int32(is_size_var=True) tree_order = T.match_buffer(tree_order_handle, (total_tree_order_len, 2), "int32", offset_factor=1) # with T.block("root"): assert rotary_mode == 0, "Inline rotary mode is not supported in tree attention." 
for lbx in T.thread_binding(16, thread="blockIdx.x"): for lby in T.thread_binding(8, thread="blockIdx.y"): for lty in T.thread_binding(4, thread="threadIdx.y"): for ltx in T.thread_binding(32, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() tile_id = T.alloc_buffer((1,), "int32", scope="local") batch_idx = T.alloc_buffer((1,), "int32", scope="local") batch_tiles = T.alloc_buffer((1,), "int32", scope="local") batch_rows = T.alloc_buffer((1,), "int32", scope="local") iterator = T.alloc_buffer((1,), "int32", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") Q_smem = T.alloc_buffer((16, 256), "float16", scope="shared") K_smem = T.alloc_buffer((32, 256), "float16", scope="shared") V_smem = T.alloc_buffer((32, 256), "float16", scope="shared") S_smem = T.alloc_buffer((16, 32), scope="shared") S_local = T.alloc_buffer((16, 32), scope="local") O_local = T.alloc_buffer((16, 256), scope="local") m_smem = T.alloc_buffer((16,), scope="shared") m_prev_smem = T.alloc_buffer((16,), scope="shared") d_smem = T.alloc_buffer((16,), scope="shared") m_new = T.alloc_buffer((1,), scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_new = T.alloc_buffer((1,), scope="local") tile_id[0] = bx batch_idx[0] = 0 batch_rows[0] = q_indptr[1] - q_indptr[0] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 while T.tvm_thread_invariant(batch_idx[0] < batch_size): while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: tile_id[0] = tile_id[0] - batch_tiles[0] batch_idx[0] = batch_idx[0] + 1 if batch_idx[0] < batch_size: b_idx: T.int32 = batch_idx[0] batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] LH_start: T.int32 = tile_id[0] * 16 q_indptr_val: T.int32 = q_indptr[b_idx] cur_page_indptr_begin: T.int32 = page_indptr[b_idx] cur_page_indptr_end: 
T.int32 = page_indptr[b_idx + 1] kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 32 + length_info[b_idx], 0) T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: m_smem[row] = T.float32(-50000.0) d_smem[row] = T.float32(1.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads() T.writes(O_local[i, j]) O_local[i, j] = T.float32(0.0) T.tvm_storage_sync("shared") for li_lj_fused_0 in range(8): for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for li_lj_fused_3 in T.vectorized(4): with T.block("Q_load"): i = T.axis.spatial(16, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 256) j = T.axis.spatial(256, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = q_indptr_val + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: freq = T.float32() Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, q[cur_L, cur_H_qo, j + 128] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 128]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), q[cur_L, cur_H_qo, j]) else: Q_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") for iterator_1 in range((kv_chunk_len[0] + 31) // 32): L_kv_start: T.int32 = iterator_1 * 32 for 
lz_ly_fused_0 in range(16): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("K_load"): i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 256) j = T.axis.spatial(256, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32 = cur_L page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 32] page_offset: T.int32 = seq_offset % 32 K_smem[i, j] = pages[page_no, 0, by, page_offset, j] else: K_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") for lz_ly_fused_0 in range(16): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("V_load"): i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 256) j = T.axis.spatial(256, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32 = cur_L page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 32] page_offset: T.int32 = seq_offset % 32 V_smem[i, j] = pages[page_no, 1, by, page_offset, j] else: V_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") with T.block(""): T.reads(Q_smem[0:16, 0:256], K_smem[0:32, 0:256]) T.writes(S_local[0:16, 0:32]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(2, 2): with T.block("S_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 2 + li_1_init) j = 
T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 2 + lj_1_init) T.reads() T.writes(S_local[i, j]) S_local[i, j] = T.float32(0.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, li_1, lj_1, lk_1 in T.grid(32, 2, 2, 8): with T.block("S_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 2 + li_1) j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 2 + lj_1) k = T.axis.reduce(256, lk_0 * 8 + lk_1) T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k]) T.writes(S_local[i, j]) S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.090168440055560212) T.tvm_storage_sync("shared") for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(2, 2): with T.block("S_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 2 + li_1) j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 2 + lj_1) T.reads(S_local[i, j]) T.writes(S_smem[i, j]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update1"): T.reads(m_smem[row], kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + 
row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i]) T.writes(m_prev[i], m_new[i], d_new[i]) m_prev[i] = m_smem[row] m_new[i] = m_smem[row] row_: T.int32 = LH_start + row for j in range(32): if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx with T.block("update"): T.reads(kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], 
q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i]) T.writes(S_smem[row, 0:32]) for j in range(32): if row < 16: row_: T.int32 = LH_start + row if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update"): T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i]) T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) for j in range(32): d_new[i] = d_new[i] + S_smem[row, j] m_smem[row] = m_new[i] d_smem[row] = d_new[i] m_prev_smem[row] = m_prev[i] T.tvm_storage_sync("shared") with T.block(""): T.reads(m_prev_smem[0:16], m_smem[0:16], S_smem[0:16, 0:32], V_smem[0:32, 0:256]) T.writes(O_local[0:16, 0:256]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(4, 8): with T.block("O_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 32 * 4 + li_1_init) j = T.axis.spatial(256, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 32 * 8 + lj_1_init) T.reads() T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] * 
T.exp2(m_prev_smem[i] - m_smem[i]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8): with T.block("O_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) k = T.axis.reduce(32, lk_0 * 8 + lk_1) T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j]) T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) for li_0 in range(1): for li_1 in T.thread_binding(4, thread="threadIdx.y"): for li_2 in T.thread_binding(32, thread="threadIdx.x"): with T.block("lse_store"): i = T.axis.spatial(16, li_0 * 128 + li_1 * 32 + li_2) T.where((li_0 * 4 + li_1) * 32 + li_2 < 16) T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) tile_id[0] = tile_id[0] + 16 @R.function def alloc_embedding_tensor() -> R.Tensor((2048, 2048), dtype="float16"): 
R.func_attr({"relax.memory_plan_dynamic_func_output": True}) gv: R.Tensor((2048, 2048), dtype="float16") = R.builtin.alloc_tensor(R.shape([2048, 2048]), R.dtype("float16"), R.prim_value(0), R.str("global")) return gv @R.function def batch_decode(input_embeds: R.Tensor(("batch_size", 1, 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), 
R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), 
R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), 
R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"))) -> R.Tuple(R.Tensor(("batch_size", 1, "vocab_size"), dtype="float32"), R.Object): batch_size = T.int64() vocab_size = T.int64() R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module with R.dataflow(): gpt_neox_layers_0_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[2] gpt_neox_layers_0_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[3] gpt_neox_layers_0_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[4] gpt_neox_layers_0_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[5] gpt_neox_layers_0_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[6] gpt_neox_layers_0_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[7] gpt_neox_layers_0_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[8] gpt_neox_layers_0_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[9] gpt_neox_layers_0_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[10] gpt_neox_layers_0_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[11] gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[12] gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[13] 
gpt_neox_layers_0_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[14] gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[15] gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[16] gpt_neox_layers_0_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[17] gpt_neox_layers_1_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[18] gpt_neox_layers_1_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[19] gpt_neox_layers_1_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[20] gpt_neox_layers_1_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[21] gpt_neox_layers_1_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[22] gpt_neox_layers_1_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[23] gpt_neox_layers_1_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[24] gpt_neox_layers_1_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[25] gpt_neox_layers_1_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[26] gpt_neox_layers_1_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[27] gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[28] gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[29] gpt_neox_layers_1_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[30] gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[31] gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[32] 
gpt_neox_layers_1_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[33] gpt_neox_layers_2_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[34] gpt_neox_layers_2_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[35] gpt_neox_layers_2_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[36] gpt_neox_layers_2_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[37] gpt_neox_layers_2_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[38] gpt_neox_layers_2_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[39] gpt_neox_layers_2_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[40] gpt_neox_layers_2_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[41] gpt_neox_layers_2_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[42] gpt_neox_layers_2_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[43] gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[44] gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[45] gpt_neox_layers_2_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[46] gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[47] gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[48] gpt_neox_layers_2_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[49] gpt_neox_layers_3_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[50] gpt_neox_layers_3_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[51] gpt_neox_layers_3_post_attention_layernorm_weight4: 
R.Tensor((2048,), dtype="float16") = packed_params[52] gpt_neox_layers_3_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[53] gpt_neox_layers_3_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[54] gpt_neox_layers_3_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[55] gpt_neox_layers_3_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[56] gpt_neox_layers_3_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[57] gpt_neox_layers_3_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[58] gpt_neox_layers_3_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[59] gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[60] gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[61] gpt_neox_layers_3_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[62] gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[63] gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[64] gpt_neox_layers_3_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[65] gpt_neox_layers_4_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[66] gpt_neox_layers_4_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[67] gpt_neox_layers_4_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[68] gpt_neox_layers_4_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[69] gpt_neox_layers_4_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[70] gpt_neox_layers_4_attention_query_key_value_q_scale4: R.Tensor((6144, 
64), dtype="float16") = packed_params[71] gpt_neox_layers_4_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[72] gpt_neox_layers_4_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[73] gpt_neox_layers_4_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[74] gpt_neox_layers_4_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[75] gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[76] gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[77] gpt_neox_layers_4_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[78] gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[79] gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[80] gpt_neox_layers_4_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[81] gpt_neox_layers_5_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[82] gpt_neox_layers_5_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[83] gpt_neox_layers_5_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[84] gpt_neox_layers_5_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[85] gpt_neox_layers_5_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[86] gpt_neox_layers_5_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[87] gpt_neox_layers_5_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[88] gpt_neox_layers_5_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[89] gpt_neox_layers_5_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = 
packed_params[90] gpt_neox_layers_5_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[91] gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[92] gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[93] gpt_neox_layers_5_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[94] gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[95] gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[96] gpt_neox_layers_5_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[97] gpt_neox_layers_6_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[98] gpt_neox_layers_6_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[99] gpt_neox_layers_6_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[100] gpt_neox_layers_6_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[101] gpt_neox_layers_6_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[102] gpt_neox_layers_6_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[103] gpt_neox_layers_6_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[104] gpt_neox_layers_6_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[105] gpt_neox_layers_6_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[106] gpt_neox_layers_6_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[107] gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[108] gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[109] 
gpt_neox_layers_6_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[110] gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[111] gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[112] gpt_neox_layers_6_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[113] gpt_neox_layers_7_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[114] gpt_neox_layers_7_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[115] gpt_neox_layers_7_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[116] gpt_neox_layers_7_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[117] gpt_neox_layers_7_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[118] gpt_neox_layers_7_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[119] gpt_neox_layers_7_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[120] gpt_neox_layers_7_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[121] gpt_neox_layers_7_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[122] gpt_neox_layers_7_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[123] gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[124] gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[125] gpt_neox_layers_7_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[126] gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[127] gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[128] 
gpt_neox_layers_7_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[129] gpt_neox_layers_8_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[130] gpt_neox_layers_8_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[131] gpt_neox_layers_8_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[132] gpt_neox_layers_8_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[133] gpt_neox_layers_8_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[134] gpt_neox_layers_8_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[135] gpt_neox_layers_8_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[136] gpt_neox_layers_8_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[137] gpt_neox_layers_8_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[138] gpt_neox_layers_8_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[139] gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[140] gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[141] gpt_neox_layers_8_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[142] gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[143] gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[144] gpt_neox_layers_8_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[145] gpt_neox_layers_9_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[146] gpt_neox_layers_9_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[147] 
gpt_neox_layers_9_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[148] gpt_neox_layers_9_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[149] gpt_neox_layers_9_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[150] gpt_neox_layers_9_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[151] gpt_neox_layers_9_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[152] gpt_neox_layers_9_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[153] gpt_neox_layers_9_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[154] gpt_neox_layers_9_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[155] gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[156] gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[157] gpt_neox_layers_9_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[158] gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[159] gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[160] gpt_neox_layers_9_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[161] gpt_neox_layers_10_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[162] gpt_neox_layers_10_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[163] gpt_neox_layers_10_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[164] gpt_neox_layers_10_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[165] gpt_neox_layers_10_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = 
packed_params[166] gpt_neox_layers_10_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[167] gpt_neox_layers_10_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[168] gpt_neox_layers_10_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[169] gpt_neox_layers_10_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[170] gpt_neox_layers_10_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[171] gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[172] gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[173] gpt_neox_layers_10_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[174] gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[175] gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[176] gpt_neox_layers_10_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[177] gpt_neox_layers_11_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[178] gpt_neox_layers_11_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[179] gpt_neox_layers_11_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[180] gpt_neox_layers_11_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[181] gpt_neox_layers_11_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[182] gpt_neox_layers_11_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[183] gpt_neox_layers_11_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[184] gpt_neox_layers_11_attention_dense_q_weight4: R.Tensor((2048, 256), 
dtype="uint32") = packed_params[185] gpt_neox_layers_11_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[186] gpt_neox_layers_11_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[187] gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[188] gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[189] gpt_neox_layers_11_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[190] gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[191] gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[192] gpt_neox_layers_11_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[193] gpt_neox_layers_12_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[194] gpt_neox_layers_12_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[195] gpt_neox_layers_12_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[196] gpt_neox_layers_12_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[197] gpt_neox_layers_12_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[198] gpt_neox_layers_12_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[199] gpt_neox_layers_12_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[200] gpt_neox_layers_12_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[201] gpt_neox_layers_12_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[202] gpt_neox_layers_12_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[203] gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), 
dtype="uint32") = packed_params[204] gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[205] gpt_neox_layers_12_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[206] gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[207] gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[208] gpt_neox_layers_12_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[209] gpt_neox_layers_13_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[210] gpt_neox_layers_13_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[211] gpt_neox_layers_13_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[212] gpt_neox_layers_13_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[213] gpt_neox_layers_13_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[214] gpt_neox_layers_13_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[215] gpt_neox_layers_13_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[216] gpt_neox_layers_13_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[217] gpt_neox_layers_13_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[218] gpt_neox_layers_13_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[219] gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[220] gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[221] gpt_neox_layers_13_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[222] gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), 
dtype="uint32") = packed_params[223] gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[224] gpt_neox_layers_13_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[225] gpt_neox_layers_14_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[226] gpt_neox_layers_14_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[227] gpt_neox_layers_14_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[228] gpt_neox_layers_14_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[229] gpt_neox_layers_14_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[230] gpt_neox_layers_14_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[231] gpt_neox_layers_14_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[232] gpt_neox_layers_14_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[233] gpt_neox_layers_14_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[234] gpt_neox_layers_14_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[235] gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[236] gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[237] gpt_neox_layers_14_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[238] gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[239] gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[240] gpt_neox_layers_14_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[241] gpt_neox_layers_15_input_layernorm_weight4: R.Tensor((2048,), 
dtype="float16") = packed_params[242] gpt_neox_layers_15_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[243] gpt_neox_layers_15_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[244] gpt_neox_layers_15_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[245] gpt_neox_layers_15_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[246] gpt_neox_layers_15_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[247] gpt_neox_layers_15_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[248] gpt_neox_layers_15_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[249] gpt_neox_layers_15_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[250] gpt_neox_layers_15_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[251] gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[252] gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[253] gpt_neox_layers_15_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[254] gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[255] gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[256] gpt_neox_layers_15_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[257] gpt_neox_final_layer_norm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[258] gpt_neox_final_layer_norm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[259] embed_out_q_weight4: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[260] embed_out_q_scale4: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[261] layer_norm99 = 
R.call_tir(cls.layer_norm, (input_embeds, gpt_neox_layers_0_input_layernorm_weight4, gpt_neox_layers_0_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv245 = R.call_tir(cls.dequantize1, (gpt_neox_layers_0_attention_query_key_value_q_weight4, gpt_neox_layers_0_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv = R.call_tir(cls.fused_NT_matmul_add, (layer_norm99, lv245, gpt_neox_layers_0_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape192 = R.call_tir(cls.reshape, (lv,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape193 = R.call_tir(cls.reshape1, (reshape192,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv246 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape193), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape194 = R.call_tir(cls.reshape2, (lv246,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape195 = R.call_tir(cls.reshape3, (reshape194,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv247 = R.call_tir(cls.dequantize2, (gpt_neox_layers_0_attention_dense_q_weight4, gpt_neox_layers_0_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm100 = R.call_tir(cls.layer_norm, (input_embeds, gpt_neox_layers_0_post_attention_layernorm_weight4, gpt_neox_layers_0_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv248 = R.call_tir(cls.dequantize3, (gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv1 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm100, lv248, gpt_neox_layers_0_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv249 = 
R.call_tir(cls.dequantize4, (gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv2 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv1, lv249, gpt_neox_layers_0_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv3 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape195, lv247, gpt_neox_layers_0_attention_dense_bias4, lv2, input_embeds), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm101 = R.call_tir(cls.layer_norm, (lv3, gpt_neox_layers_1_input_layernorm_weight4, gpt_neox_layers_1_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv250 = R.call_tir(cls.dequantize1, (gpt_neox_layers_1_attention_query_key_value_q_weight4, gpt_neox_layers_1_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv4 = R.call_tir(cls.fused_NT_matmul_add, (layer_norm101, lv250, gpt_neox_layers_1_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape196 = R.call_tir(cls.reshape, (lv4,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape197 = R.call_tir(cls.reshape1, (reshape196,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv251 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape197), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape198 = R.call_tir(cls.reshape2, (lv251,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape199 = R.call_tir(cls.reshape3, (reshape198,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv252 = R.call_tir(cls.dequantize2, (gpt_neox_layers_1_attention_dense_q_weight4, gpt_neox_layers_1_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm102 = R.call_tir(cls.layer_norm, (lv3, 
gpt_neox_layers_1_post_attention_layernorm_weight4, gpt_neox_layers_1_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv253 = R.call_tir(cls.dequantize3, (gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv5 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm102, lv253, gpt_neox_layers_1_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv254 = R.call_tir(cls.dequantize4, (gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv6 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv5, lv254, gpt_neox_layers_1_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv7 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape199, lv252, gpt_neox_layers_1_attention_dense_bias4, lv6, lv3), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm103 = R.call_tir(cls.layer_norm, (lv7, gpt_neox_layers_2_input_layernorm_weight4, gpt_neox_layers_2_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv255 = R.call_tir(cls.dequantize1, (gpt_neox_layers_2_attention_query_key_value_q_weight4, gpt_neox_layers_2_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv8 = R.call_tir(cls.fused_NT_matmul_add, (layer_norm103, lv255, gpt_neox_layers_2_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape200 = R.call_tir(cls.reshape, (lv8,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape201 = R.call_tir(cls.reshape1, (reshape200,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv256 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), 
reshape201), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape202 = R.call_tir(cls.reshape2, (lv256,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape203 = R.call_tir(cls.reshape3, (reshape202,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv257 = R.call_tir(cls.dequantize2, (gpt_neox_layers_2_attention_dense_q_weight4, gpt_neox_layers_2_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm104 = R.call_tir(cls.layer_norm, (lv7, gpt_neox_layers_2_post_attention_layernorm_weight4, gpt_neox_layers_2_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv258 = R.call_tir(cls.dequantize3, (gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv9 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm104, lv258, gpt_neox_layers_2_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv259 = R.call_tir(cls.dequantize4, (gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv10 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv9, lv259, gpt_neox_layers_2_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv11 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape203, lv257, gpt_neox_layers_2_attention_dense_bias4, lv10, lv7), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm105 = R.call_tir(cls.layer_norm, (lv11, gpt_neox_layers_3_input_layernorm_weight4, gpt_neox_layers_3_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv260 = R.call_tir(cls.dequantize1, (gpt_neox_layers_3_attention_query_key_value_q_weight4, gpt_neox_layers_3_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv12 = 
R.call_tir(cls.fused_NT_matmul_add, (layer_norm105, lv260, gpt_neox_layers_3_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape204 = R.call_tir(cls.reshape, (lv12,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape205 = R.call_tir(cls.reshape1, (reshape204,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv261 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape205), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape206 = R.call_tir(cls.reshape2, (lv261,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape207 = R.call_tir(cls.reshape3, (reshape206,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv262 = R.call_tir(cls.dequantize2, (gpt_neox_layers_3_attention_dense_q_weight4, gpt_neox_layers_3_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm106 = R.call_tir(cls.layer_norm, (lv11, gpt_neox_layers_3_post_attention_layernorm_weight4, gpt_neox_layers_3_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv263 = R.call_tir(cls.dequantize3, (gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv13 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm106, lv263, gpt_neox_layers_3_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv264 = R.call_tir(cls.dequantize4, (gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv14 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv13, lv264, gpt_neox_layers_3_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv15 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, 
(reshape207, lv262, gpt_neox_layers_3_attention_dense_bias4, lv14, lv11), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm107 = R.call_tir(cls.layer_norm, (lv15, gpt_neox_layers_4_input_layernorm_weight4, gpt_neox_layers_4_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv265 = R.call_tir(cls.dequantize1, (gpt_neox_layers_4_attention_query_key_value_q_weight4, gpt_neox_layers_4_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv16 = R.call_tir(cls.fused_NT_matmul_add, (layer_norm107, lv265, gpt_neox_layers_4_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape208 = R.call_tir(cls.reshape, (lv16,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape209 = R.call_tir(cls.reshape1, (reshape208,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv266 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape209), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape210 = R.call_tir(cls.reshape2, (lv266,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape211 = R.call_tir(cls.reshape3, (reshape210,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv267 = R.call_tir(cls.dequantize2, (gpt_neox_layers_4_attention_dense_q_weight4, gpt_neox_layers_4_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm108 = R.call_tir(cls.layer_norm, (lv15, gpt_neox_layers_4_post_attention_layernorm_weight4, gpt_neox_layers_4_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv268 = R.call_tir(cls.dequantize3, (gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv17 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, 
(layer_norm108, lv268, gpt_neox_layers_4_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv269 = R.call_tir(cls.dequantize4, (gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv18 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv17, lv269, gpt_neox_layers_4_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv19 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape211, lv267, gpt_neox_layers_4_attention_dense_bias4, lv18, lv15), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm109 = R.call_tir(cls.layer_norm, (lv19, gpt_neox_layers_5_input_layernorm_weight4, gpt_neox_layers_5_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv270 = R.call_tir(cls.dequantize1, (gpt_neox_layers_5_attention_query_key_value_q_weight4, gpt_neox_layers_5_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv20 = R.call_tir(cls.fused_NT_matmul_add, (layer_norm109, lv270, gpt_neox_layers_5_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape212 = R.call_tir(cls.reshape, (lv20,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape213 = R.call_tir(cls.reshape1, (reshape212,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv271 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape213), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape214 = R.call_tir(cls.reshape2, (lv271,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape215 = R.call_tir(cls.reshape3, (reshape214,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv272 = R.call_tir(cls.dequantize2, (gpt_neox_layers_5_attention_dense_q_weight4, 
gpt_neox_layers_5_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm110 = R.call_tir(cls.layer_norm, (lv19, gpt_neox_layers_5_post_attention_layernorm_weight4, gpt_neox_layers_5_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv273 = R.call_tir(cls.dequantize3, (gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv21 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm110, lv273, gpt_neox_layers_5_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv274 = R.call_tir(cls.dequantize4, (gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv22 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv21, lv274, gpt_neox_layers_5_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv23 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape215, lv272, gpt_neox_layers_5_attention_dense_bias4, lv22, lv19), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm111 = R.call_tir(cls.layer_norm, (lv23, gpt_neox_layers_6_input_layernorm_weight4, gpt_neox_layers_6_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv275 = R.call_tir(cls.dequantize1, (gpt_neox_layers_6_attention_query_key_value_q_weight4, gpt_neox_layers_6_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv24 = R.call_tir(cls.fused_NT_matmul_add, (layer_norm111, lv275, gpt_neox_layers_6_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape216 = R.call_tir(cls.reshape, (lv24,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape217 = R.call_tir(cls.reshape1, (reshape216,), out_sinfo=R.Tensor((batch_size, 24, 256), 
dtype="float16")) lv276 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape217), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape218 = R.call_tir(cls.reshape2, (lv276,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape219 = R.call_tir(cls.reshape3, (reshape218,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv277 = R.call_tir(cls.dequantize2, (gpt_neox_layers_6_attention_dense_q_weight4, gpt_neox_layers_6_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm112 = R.call_tir(cls.layer_norm, (lv23, gpt_neox_layers_6_post_attention_layernorm_weight4, gpt_neox_layers_6_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv278 = R.call_tir(cls.dequantize3, (gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv25 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm112, lv278, gpt_neox_layers_6_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv279 = R.call_tir(cls.dequantize4, (gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv26 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv25, lv279, gpt_neox_layers_6_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv27 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape219, lv277, gpt_neox_layers_6_attention_dense_bias4, lv26, lv23), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm113 = R.call_tir(cls.layer_norm, (lv27, gpt_neox_layers_7_input_layernorm_weight4, gpt_neox_layers_7_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv280 = R.call_tir(cls.dequantize1, 
(gpt_neox_layers_7_attention_query_key_value_q_weight4, gpt_neox_layers_7_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv28 = R.call_tir(cls.fused_NT_matmul_add, (layer_norm113, lv280, gpt_neox_layers_7_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape220 = R.call_tir(cls.reshape, (lv28,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape221 = R.call_tir(cls.reshape1, (reshape220,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv281 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape221), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape222 = R.call_tir(cls.reshape2, (lv281,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape223 = R.call_tir(cls.reshape3, (reshape222,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv282 = R.call_tir(cls.dequantize2, (gpt_neox_layers_7_attention_dense_q_weight4, gpt_neox_layers_7_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm114 = R.call_tir(cls.layer_norm, (lv27, gpt_neox_layers_7_post_attention_layernorm_weight4, gpt_neox_layers_7_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv283 = R.call_tir(cls.dequantize3, (gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv29 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm114, lv283, gpt_neox_layers_7_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv284 = R.call_tir(cls.dequantize4, (gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv30 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv29, 
lv284, gpt_neox_layers_7_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv31 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape223, lv282, gpt_neox_layers_7_attention_dense_bias4, lv30, lv27), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm115 = R.call_tir(cls.layer_norm, (lv31, gpt_neox_layers_8_input_layernorm_weight4, gpt_neox_layers_8_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv285 = R.call_tir(cls.dequantize1, (gpt_neox_layers_8_attention_query_key_value_q_weight4, gpt_neox_layers_8_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv32 = R.call_tir(cls.fused_NT_matmul_add, (layer_norm115, lv285, gpt_neox_layers_8_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape224 = R.call_tir(cls.reshape, (lv32,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape225 = R.call_tir(cls.reshape1, (reshape224,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv286 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape225), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape226 = R.call_tir(cls.reshape2, (lv286,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape227 = R.call_tir(cls.reshape3, (reshape226,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv287 = R.call_tir(cls.dequantize2, (gpt_neox_layers_8_attention_dense_q_weight4, gpt_neox_layers_8_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm116 = R.call_tir(cls.layer_norm, (lv31, gpt_neox_layers_8_post_attention_layernorm_weight4, gpt_neox_layers_8_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv288 = R.call_tir(cls.dequantize3, 
(gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv33 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm116, lv288, gpt_neox_layers_8_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv289 = R.call_tir(cls.dequantize4, (gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv34 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv33, lv289, gpt_neox_layers_8_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv35 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape227, lv287, gpt_neox_layers_8_attention_dense_bias4, lv34, lv31), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm117 = R.call_tir(cls.layer_norm, (lv35, gpt_neox_layers_9_input_layernorm_weight4, gpt_neox_layers_9_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv290 = R.call_tir(cls.dequantize1, (gpt_neox_layers_9_attention_query_key_value_q_weight4, gpt_neox_layers_9_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv36 = R.call_tir(cls.fused_NT_matmul_add, (layer_norm117, lv290, gpt_neox_layers_9_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape228 = R.call_tir(cls.reshape, (lv36,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape229 = R.call_tir(cls.reshape1, (reshape228,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv291 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape229), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape230 = R.call_tir(cls.reshape2, (lv291,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) 
reshape231 = R.call_tir(cls.reshape3, (reshape230,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv292 = R.call_tir(cls.dequantize2, (gpt_neox_layers_9_attention_dense_q_weight4, gpt_neox_layers_9_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm118 = R.call_tir(cls.layer_norm, (lv35, gpt_neox_layers_9_post_attention_layernorm_weight4, gpt_neox_layers_9_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv293 = R.call_tir(cls.dequantize3, (gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv37 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm118, lv293, gpt_neox_layers_9_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv294 = R.call_tir(cls.dequantize4, (gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv38 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv37, lv294, gpt_neox_layers_9_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv39 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape231, lv292, gpt_neox_layers_9_attention_dense_bias4, lv38, lv35), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm119 = R.call_tir(cls.layer_norm, (lv39, gpt_neox_layers_10_input_layernorm_weight4, gpt_neox_layers_10_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv295 = R.call_tir(cls.dequantize1, (gpt_neox_layers_10_attention_query_key_value_q_weight4, gpt_neox_layers_10_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv40 = R.call_tir(cls.fused_NT_matmul_add, (layer_norm119, lv295, gpt_neox_layers_10_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape232 = 
R.call_tir(cls.reshape, (lv40,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape233 = R.call_tir(cls.reshape1, (reshape232,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv296 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape233), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape234 = R.call_tir(cls.reshape2, (lv296,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape235 = R.call_tir(cls.reshape3, (reshape234,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv297 = R.call_tir(cls.dequantize2, (gpt_neox_layers_10_attention_dense_q_weight4, gpt_neox_layers_10_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm120 = R.call_tir(cls.layer_norm, (lv39, gpt_neox_layers_10_post_attention_layernorm_weight4, gpt_neox_layers_10_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv298 = R.call_tir(cls.dequantize3, (gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv41 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm120, lv298, gpt_neox_layers_10_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv299 = R.call_tir(cls.dequantize4, (gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv42 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv41, lv299, gpt_neox_layers_10_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv43 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape235, lv297, gpt_neox_layers_10_attention_dense_bias4, lv42, lv39), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm121 = 
R.call_tir(cls.layer_norm, (lv43, gpt_neox_layers_11_input_layernorm_weight4, gpt_neox_layers_11_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv300 = R.call_tir(cls.dequantize1, (gpt_neox_layers_11_attention_query_key_value_q_weight4, gpt_neox_layers_11_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv44 = R.call_tir(cls.fused_NT_matmul_add, (layer_norm121, lv300, gpt_neox_layers_11_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape236 = R.call_tir(cls.reshape, (lv44,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape237 = R.call_tir(cls.reshape1, (reshape236,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv301 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape237), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape238 = R.call_tir(cls.reshape2, (lv301,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape239 = R.call_tir(cls.reshape3, (reshape238,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv302 = R.call_tir(cls.dequantize2, (gpt_neox_layers_11_attention_dense_q_weight4, gpt_neox_layers_11_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm122 = R.call_tir(cls.layer_norm, (lv43, gpt_neox_layers_11_post_attention_layernorm_weight4, gpt_neox_layers_11_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv303 = R.call_tir(cls.dequantize3, (gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv45 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm122, lv303, gpt_neox_layers_11_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv304 = 
R.call_tir(cls.dequantize4, (gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv46 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv45, lv304, gpt_neox_layers_11_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv47 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape239, lv302, gpt_neox_layers_11_attention_dense_bias4, lv46, lv43), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm123 = R.call_tir(cls.layer_norm, (lv47, gpt_neox_layers_12_input_layernorm_weight4, gpt_neox_layers_12_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv305 = R.call_tir(cls.dequantize1, (gpt_neox_layers_12_attention_query_key_value_q_weight4, gpt_neox_layers_12_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv48 = R.call_tir(cls.fused_NT_matmul_add, (layer_norm123, lv305, gpt_neox_layers_12_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape240 = R.call_tir(cls.reshape, (lv48,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape241 = R.call_tir(cls.reshape1, (reshape240,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv306 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape241), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape242 = R.call_tir(cls.reshape2, (lv306,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape243 = R.call_tir(cls.reshape3, (reshape242,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv307 = R.call_tir(cls.dequantize2, (gpt_neox_layers_12_attention_dense_q_weight4, gpt_neox_layers_12_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm124 = 
R.call_tir(cls.layer_norm, (lv47, gpt_neox_layers_12_post_attention_layernorm_weight4, gpt_neox_layers_12_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv308 = R.call_tir(cls.dequantize3, (gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv49 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm124, lv308, gpt_neox_layers_12_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv309 = R.call_tir(cls.dequantize4, (gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv50 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv49, lv309, gpt_neox_layers_12_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv51 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape243, lv307, gpt_neox_layers_12_attention_dense_bias4, lv50, lv47), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm125 = R.call_tir(cls.layer_norm, (lv51, gpt_neox_layers_13_input_layernorm_weight4, gpt_neox_layers_13_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv310 = R.call_tir(cls.dequantize1, (gpt_neox_layers_13_attention_query_key_value_q_weight4, gpt_neox_layers_13_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv52 = R.call_tir(cls.fused_NT_matmul_add, (layer_norm125, lv310, gpt_neox_layers_13_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape244 = R.call_tir(cls.reshape, (lv52,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape245 = R.call_tir(cls.reshape1, (reshape244,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv311 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", 
(paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape245), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape246 = R.call_tir(cls.reshape2, (lv311,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape247 = R.call_tir(cls.reshape3, (reshape246,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv312 = R.call_tir(cls.dequantize2, (gpt_neox_layers_13_attention_dense_q_weight4, gpt_neox_layers_13_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm126 = R.call_tir(cls.layer_norm, (lv51, gpt_neox_layers_13_post_attention_layernorm_weight4, gpt_neox_layers_13_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv313 = R.call_tir(cls.dequantize3, (gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv53 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm126, lv313, gpt_neox_layers_13_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv314 = R.call_tir(cls.dequantize4, (gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv54 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv53, lv314, gpt_neox_layers_13_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv55 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape247, lv312, gpt_neox_layers_13_attention_dense_bias4, lv54, lv51), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm127 = R.call_tir(cls.layer_norm, (lv55, gpt_neox_layers_14_input_layernorm_weight4, gpt_neox_layers_14_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv315 = R.call_tir(cls.dequantize1, (gpt_neox_layers_14_attention_query_key_value_q_weight4, 
gpt_neox_layers_14_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv56 = R.call_tir(cls.fused_NT_matmul_add, (layer_norm127, lv315, gpt_neox_layers_14_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape248 = R.call_tir(cls.reshape, (lv56,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape249 = R.call_tir(cls.reshape1, (reshape248,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv316 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape249), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape250 = R.call_tir(cls.reshape2, (lv316,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape251 = R.call_tir(cls.reshape3, (reshape250,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv317 = R.call_tir(cls.dequantize2, (gpt_neox_layers_14_attention_dense_q_weight4, gpt_neox_layers_14_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm128 = R.call_tir(cls.layer_norm, (lv55, gpt_neox_layers_14_post_attention_layernorm_weight4, gpt_neox_layers_14_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv318 = R.call_tir(cls.dequantize3, (gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv57 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm128, lv318, gpt_neox_layers_14_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv319 = R.call_tir(cls.dequantize4, (gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv58 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv57, lv319, 
gpt_neox_layers_14_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv59 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape251, lv317, gpt_neox_layers_14_attention_dense_bias4, lv58, lv55), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm129 = R.call_tir(cls.layer_norm, (lv59, gpt_neox_layers_15_input_layernorm_weight4, gpt_neox_layers_15_input_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv320 = R.call_tir(cls.dequantize1, (gpt_neox_layers_15_attention_query_key_value_q_weight4, gpt_neox_layers_15_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv60 = R.call_tir(cls.fused_NT_matmul_add, (layer_norm129, lv320, gpt_neox_layers_15_attention_query_key_value_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16")) reshape252 = R.call_tir(cls.reshape, (lv60,), out_sinfo=R.Tensor((batch_size, 1, 24, 256), dtype="float16")) reshape253 = R.call_tir(cls.reshape1, (reshape252,), out_sinfo=R.Tensor((batch_size, 24, 256), dtype="float16")) lv321 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape253), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape254 = R.call_tir(cls.reshape2, (lv321,), out_sinfo=R.Tensor((batch_size, 1, 8, 256), dtype="float16")) reshape255 = R.call_tir(cls.reshape3, (reshape254,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv322 = R.call_tir(cls.dequantize2, (gpt_neox_layers_15_attention_dense_q_weight4, gpt_neox_layers_15_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm130 = R.call_tir(cls.layer_norm, (lv59, gpt_neox_layers_15_post_attention_layernorm_weight4, gpt_neox_layers_15_post_attention_layernorm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv323 = R.call_tir(cls.dequantize3, 
(gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv61 = R.call_tir(cls.fused_NT_matmul2_add2_gelu_cast, (layer_norm130, lv323, gpt_neox_layers_15_mlp_dense_h_to_4h_bias4), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16")) lv324 = R.call_tir(cls.dequantize4, (gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv62 = R.call_tir(cls.fused_NT_matmul3_add3_cast1, (lv61, lv324, gpt_neox_layers_15_mlp_dense_4h_to_h_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv63 = R.call_tir(cls.fused_NT_matmul1_add1_add4_add4, (reshape255, lv322, gpt_neox_layers_15_attention_dense_bias4, lv62, lv59), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) layer_norm131 = R.call_tir(cls.layer_norm, (lv63, gpt_neox_final_layer_norm_weight4, gpt_neox_final_layer_norm_bias4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16")) lv325 = R.call_tir(cls.dequantize, (embed_out_q_weight4, embed_out_q_scale4), out_sinfo=R.Tensor((vocab_size, 2048), dtype="float16")) lv64 = R.call_tir(cls.fused_NT_matmul4_cast2, (layer_norm131, lv325), out_sinfo=R.Tensor((batch_size, 1, vocab_size), dtype="float32")) gv4: R.Tuple(R.Tensor((batch_size, 1, vocab_size), dtype="float32"), R.Object) = lv64, paged_kv_cache R.output(gv4) return gv4 @R.function def batch_prefill(input_embeds: R.Tensor((1, "seq_len", 2048), dtype="float16"), logit_positions: R.Tensor(("batch_size",), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), 
dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), 
dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), 
dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), 
dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"))) -> R.Tuple(R.Tensor((1, "batch_size", "vocab_size"), dtype="float32"), R.Object): batch_size = T.int64() vocab_size = T.int64() seq_len = T.int64() R.func_attr({"num_input": 3, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module with R.dataflow(): 
gpt_neox_layers_0_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[2] gpt_neox_layers_0_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[3] gpt_neox_layers_0_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[4] gpt_neox_layers_0_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[5] gpt_neox_layers_0_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[6] gpt_neox_layers_0_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[7] gpt_neox_layers_0_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[8] gpt_neox_layers_0_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[9] gpt_neox_layers_0_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[10] gpt_neox_layers_0_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[11] gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[12] gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[13] gpt_neox_layers_0_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[14] gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[15] gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[16] gpt_neox_layers_0_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[17] gpt_neox_layers_1_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[18] gpt_neox_layers_1_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[19] gpt_neox_layers_1_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[20] gpt_neox_layers_1_post_attention_layernorm_bias3: 
R.Tensor((2048,), dtype="float16") = packed_params[21] gpt_neox_layers_1_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[22] gpt_neox_layers_1_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[23] gpt_neox_layers_1_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[24] gpt_neox_layers_1_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[25] gpt_neox_layers_1_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[26] gpt_neox_layers_1_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[27] gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[28] gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[29] gpt_neox_layers_1_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[30] gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[31] gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[32] gpt_neox_layers_1_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[33] gpt_neox_layers_2_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[34] gpt_neox_layers_2_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[35] gpt_neox_layers_2_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[36] gpt_neox_layers_2_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[37] gpt_neox_layers_2_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[38] gpt_neox_layers_2_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[39] gpt_neox_layers_2_attention_query_key_value_bias3: 
R.Tensor((6144,), dtype="float16") = packed_params[40] gpt_neox_layers_2_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[41] gpt_neox_layers_2_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[42] gpt_neox_layers_2_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[43] gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[44] gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[45] gpt_neox_layers_2_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[46] gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[47] gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[48] gpt_neox_layers_2_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[49] gpt_neox_layers_3_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[50] gpt_neox_layers_3_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[51] gpt_neox_layers_3_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[52] gpt_neox_layers_3_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[53] gpt_neox_layers_3_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[54] gpt_neox_layers_3_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[55] gpt_neox_layers_3_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[56] gpt_neox_layers_3_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[57] gpt_neox_layers_3_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[58] gpt_neox_layers_3_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = 
packed_params[59] gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[60] gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[61] gpt_neox_layers_3_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[62] gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[63] gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[64] gpt_neox_layers_3_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[65] gpt_neox_layers_4_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[66] gpt_neox_layers_4_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[67] gpt_neox_layers_4_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[68] gpt_neox_layers_4_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[69] gpt_neox_layers_4_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[70] gpt_neox_layers_4_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[71] gpt_neox_layers_4_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[72] gpt_neox_layers_4_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[73] gpt_neox_layers_4_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[74] gpt_neox_layers_4_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[75] gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[76] gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[77] gpt_neox_layers_4_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[78] 
gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[79] gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[80] gpt_neox_layers_4_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[81] gpt_neox_layers_5_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[82] gpt_neox_layers_5_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[83] gpt_neox_layers_5_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[84] gpt_neox_layers_5_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[85] gpt_neox_layers_5_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[86] gpt_neox_layers_5_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[87] gpt_neox_layers_5_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[88] gpt_neox_layers_5_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[89] gpt_neox_layers_5_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[90] gpt_neox_layers_5_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[91] gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[92] gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[93] gpt_neox_layers_5_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[94] gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[95] gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[96] gpt_neox_layers_5_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[97] 
gpt_neox_layers_6_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[98] gpt_neox_layers_6_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[99] gpt_neox_layers_6_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[100] gpt_neox_layers_6_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[101] gpt_neox_layers_6_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[102] gpt_neox_layers_6_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[103] gpt_neox_layers_6_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[104] gpt_neox_layers_6_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[105] gpt_neox_layers_6_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[106] gpt_neox_layers_6_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[107] gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[108] gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[109] gpt_neox_layers_6_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[110] gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[111] gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[112] gpt_neox_layers_6_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[113] gpt_neox_layers_7_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[114] gpt_neox_layers_7_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[115] gpt_neox_layers_7_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[116] 
gpt_neox_layers_7_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[117] gpt_neox_layers_7_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[118] gpt_neox_layers_7_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[119] gpt_neox_layers_7_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[120] gpt_neox_layers_7_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[121] gpt_neox_layers_7_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[122] gpt_neox_layers_7_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[123] gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[124] gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[125] gpt_neox_layers_7_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[126] gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[127] gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[128] gpt_neox_layers_7_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[129] gpt_neox_layers_8_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[130] gpt_neox_layers_8_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[131] gpt_neox_layers_8_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[132] gpt_neox_layers_8_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[133] gpt_neox_layers_8_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[134] gpt_neox_layers_8_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = 
packed_params[135] gpt_neox_layers_8_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[136] gpt_neox_layers_8_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[137] gpt_neox_layers_8_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[138] gpt_neox_layers_8_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[139] gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[140] gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[141] gpt_neox_layers_8_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[142] gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[143] gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[144] gpt_neox_layers_8_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[145] gpt_neox_layers_9_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[146] gpt_neox_layers_9_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[147] gpt_neox_layers_9_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[148] gpt_neox_layers_9_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[149] gpt_neox_layers_9_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[150] gpt_neox_layers_9_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[151] gpt_neox_layers_9_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[152] gpt_neox_layers_9_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[153] gpt_neox_layers_9_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = 
packed_params[154] gpt_neox_layers_9_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[155] gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[156] gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[157] gpt_neox_layers_9_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[158] gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[159] gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[160] gpt_neox_layers_9_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[161] gpt_neox_layers_10_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[162] gpt_neox_layers_10_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[163] gpt_neox_layers_10_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[164] gpt_neox_layers_10_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[165] gpt_neox_layers_10_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[166] gpt_neox_layers_10_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[167] gpt_neox_layers_10_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[168] gpt_neox_layers_10_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[169] gpt_neox_layers_10_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[170] gpt_neox_layers_10_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[171] gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[172] gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = 
packed_params[173] gpt_neox_layers_10_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[174] gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[175] gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[176] gpt_neox_layers_10_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[177] gpt_neox_layers_11_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[178] gpt_neox_layers_11_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[179] gpt_neox_layers_11_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[180] gpt_neox_layers_11_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[181] gpt_neox_layers_11_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[182] gpt_neox_layers_11_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[183] gpt_neox_layers_11_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[184] gpt_neox_layers_11_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[185] gpt_neox_layers_11_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[186] gpt_neox_layers_11_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[187] gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[188] gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[189] gpt_neox_layers_11_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[190] gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[191] gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") 
= packed_params[192] gpt_neox_layers_11_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[193] gpt_neox_layers_12_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[194] gpt_neox_layers_12_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[195] gpt_neox_layers_12_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[196] gpt_neox_layers_12_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[197] gpt_neox_layers_12_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[198] gpt_neox_layers_12_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[199] gpt_neox_layers_12_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[200] gpt_neox_layers_12_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[201] gpt_neox_layers_12_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[202] gpt_neox_layers_12_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[203] gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[204] gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[205] gpt_neox_layers_12_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[206] gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[207] gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[208] gpt_neox_layers_12_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[209] gpt_neox_layers_13_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[210] gpt_neox_layers_13_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = 
packed_params[211] gpt_neox_layers_13_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[212] gpt_neox_layers_13_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[213] gpt_neox_layers_13_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[214] gpt_neox_layers_13_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[215] gpt_neox_layers_13_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[216] gpt_neox_layers_13_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[217] gpt_neox_layers_13_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[218] gpt_neox_layers_13_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[219] gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[220] gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[221] gpt_neox_layers_13_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[222] gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[223] gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[224] gpt_neox_layers_13_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[225] gpt_neox_layers_14_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[226] gpt_neox_layers_14_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[227] gpt_neox_layers_14_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[228] gpt_neox_layers_14_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[229] gpt_neox_layers_14_attention_query_key_value_q_weight3: R.Tensor((6144, 256), 
dtype="uint32") = packed_params[230] gpt_neox_layers_14_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[231] gpt_neox_layers_14_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[232] gpt_neox_layers_14_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[233] gpt_neox_layers_14_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[234] gpt_neox_layers_14_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[235] gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[236] gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[237] gpt_neox_layers_14_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[238] gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[239] gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[240] gpt_neox_layers_14_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[241] gpt_neox_layers_15_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[242] gpt_neox_layers_15_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[243] gpt_neox_layers_15_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[244] gpt_neox_layers_15_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[245] gpt_neox_layers_15_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[246] gpt_neox_layers_15_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[247] gpt_neox_layers_15_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[248] gpt_neox_layers_15_attention_dense_q_weight3: 
R.Tensor((2048, 256), dtype="uint32") = packed_params[249] gpt_neox_layers_15_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[250] gpt_neox_layers_15_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[251] gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[252] gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[253] gpt_neox_layers_15_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[254] gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[255] gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[256] gpt_neox_layers_15_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[257] gpt_neox_final_layer_norm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[258] gpt_neox_final_layer_norm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[259] embed_out_q_weight3: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[260] embed_out_q_scale3: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[261] layer_norm66 = R.call_tir(cls.layer_norm1, (input_embeds, gpt_neox_layers_0_input_layernorm_weight3, gpt_neox_layers_0_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv164 = R.call_tir(cls.dequantize1, (gpt_neox_layers_0_attention_query_key_value_q_weight3, gpt_neox_layers_0_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv65 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm66, lv164, gpt_neox_layers_0_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape128 = R.call_tir(cls.reshape4, (lv65,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape129 = R.call_tir(cls.reshape5, (reshape128,), 
out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv165 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape129), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape130 = R.call_tir(cls.reshape6, (lv165,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape131 = R.call_tir(cls.reshape7, (reshape130,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv166 = R.call_tir(cls.dequantize2, (gpt_neox_layers_0_attention_dense_q_weight3, gpt_neox_layers_0_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm67 = R.call_tir(cls.layer_norm1, (input_embeds, gpt_neox_layers_0_post_attention_layernorm_weight3, gpt_neox_layers_0_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv167 = R.call_tir(cls.dequantize3, (gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv66 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm67, lv167, gpt_neox_layers_0_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv168 = R.call_tir(cls.dequantize4, (gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv67 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv66, lv168, gpt_neox_layers_0_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv68 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape131, lv166, gpt_neox_layers_0_attention_dense_bias3, lv67, input_embeds), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm68 = R.call_tir(cls.layer_norm1, (lv68, gpt_neox_layers_1_input_layernorm_weight3, gpt_neox_layers_1_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv169 = 
R.call_tir(cls.dequantize1, (gpt_neox_layers_1_attention_query_key_value_q_weight3, gpt_neox_layers_1_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv69 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm68, lv169, gpt_neox_layers_1_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape132 = R.call_tir(cls.reshape4, (lv69,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape133 = R.call_tir(cls.reshape5, (reshape132,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv170 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape133), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape134 = R.call_tir(cls.reshape6, (lv170,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape135 = R.call_tir(cls.reshape7, (reshape134,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv171 = R.call_tir(cls.dequantize2, (gpt_neox_layers_1_attention_dense_q_weight3, gpt_neox_layers_1_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm69 = R.call_tir(cls.layer_norm1, (lv68, gpt_neox_layers_1_post_attention_layernorm_weight3, gpt_neox_layers_1_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv172 = R.call_tir(cls.dequantize3, (gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv70 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm69, lv172, gpt_neox_layers_1_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv173 = R.call_tir(cls.dequantize4, (gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv71 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, 
(lv70, lv173, gpt_neox_layers_1_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv72 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape135, lv171, gpt_neox_layers_1_attention_dense_bias3, lv71, lv68), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm70 = R.call_tir(cls.layer_norm1, (lv72, gpt_neox_layers_2_input_layernorm_weight3, gpt_neox_layers_2_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv174 = R.call_tir(cls.dequantize1, (gpt_neox_layers_2_attention_query_key_value_q_weight3, gpt_neox_layers_2_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv73 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm70, lv174, gpt_neox_layers_2_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape136 = R.call_tir(cls.reshape4, (lv73,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape137 = R.call_tir(cls.reshape5, (reshape136,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv175 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape137), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape138 = R.call_tir(cls.reshape6, (lv175,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape139 = R.call_tir(cls.reshape7, (reshape138,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv176 = R.call_tir(cls.dequantize2, (gpt_neox_layers_2_attention_dense_q_weight3, gpt_neox_layers_2_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm71 = R.call_tir(cls.layer_norm1, (lv72, gpt_neox_layers_2_post_attention_layernorm_weight3, gpt_neox_layers_2_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv177 = R.call_tir(cls.dequantize3, (gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight3, 
gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv74 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm71, lv177, gpt_neox_layers_2_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv178 = R.call_tir(cls.dequantize4, (gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv75 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv74, lv178, gpt_neox_layers_2_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv76 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape139, lv176, gpt_neox_layers_2_attention_dense_bias3, lv75, lv72), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm72 = R.call_tir(cls.layer_norm1, (lv76, gpt_neox_layers_3_input_layernorm_weight3, gpt_neox_layers_3_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv179 = R.call_tir(cls.dequantize1, (gpt_neox_layers_3_attention_query_key_value_q_weight3, gpt_neox_layers_3_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv77 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm72, lv179, gpt_neox_layers_3_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape140 = R.call_tir(cls.reshape4, (lv77,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape141 = R.call_tir(cls.reshape5, (reshape140,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv180 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape141), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape142 = R.call_tir(cls.reshape6, (lv180,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape143 = R.call_tir(cls.reshape7, (reshape142,), out_sinfo=R.Tensor((1, 
seq_len, 2048), dtype="float16")) lv181 = R.call_tir(cls.dequantize2, (gpt_neox_layers_3_attention_dense_q_weight3, gpt_neox_layers_3_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm73 = R.call_tir(cls.layer_norm1, (lv76, gpt_neox_layers_3_post_attention_layernorm_weight3, gpt_neox_layers_3_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv182 = R.call_tir(cls.dequantize3, (gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv78 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm73, lv182, gpt_neox_layers_3_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv183 = R.call_tir(cls.dequantize4, (gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv79 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv78, lv183, gpt_neox_layers_3_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv80 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape143, lv181, gpt_neox_layers_3_attention_dense_bias3, lv79, lv76), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm74 = R.call_tir(cls.layer_norm1, (lv80, gpt_neox_layers_4_input_layernorm_weight3, gpt_neox_layers_4_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv184 = R.call_tir(cls.dequantize1, (gpt_neox_layers_4_attention_query_key_value_q_weight3, gpt_neox_layers_4_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv81 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm74, lv184, gpt_neox_layers_4_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape144 = R.call_tir(cls.reshape4, (lv81,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) 
reshape145 = R.call_tir(cls.reshape5, (reshape144,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv185 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape145), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape146 = R.call_tir(cls.reshape6, (lv185,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape147 = R.call_tir(cls.reshape7, (reshape146,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv186 = R.call_tir(cls.dequantize2, (gpt_neox_layers_4_attention_dense_q_weight3, gpt_neox_layers_4_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm75 = R.call_tir(cls.layer_norm1, (lv80, gpt_neox_layers_4_post_attention_layernorm_weight3, gpt_neox_layers_4_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv187 = R.call_tir(cls.dequantize3, (gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv82 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm75, lv187, gpt_neox_layers_4_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv188 = R.call_tir(cls.dequantize4, (gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv83 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv82, lv188, gpt_neox_layers_4_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv84 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape147, lv186, gpt_neox_layers_4_attention_dense_bias3, lv83, lv80), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm76 = R.call_tir(cls.layer_norm1, (lv84, gpt_neox_layers_5_input_layernorm_weight3, gpt_neox_layers_5_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 
2048), dtype="float16")) lv189 = R.call_tir(cls.dequantize1, (gpt_neox_layers_5_attention_query_key_value_q_weight3, gpt_neox_layers_5_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv85 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm76, lv189, gpt_neox_layers_5_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape148 = R.call_tir(cls.reshape4, (lv85,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape149 = R.call_tir(cls.reshape5, (reshape148,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv190 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape149), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape150 = R.call_tir(cls.reshape6, (lv190,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape151 = R.call_tir(cls.reshape7, (reshape150,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv191 = R.call_tir(cls.dequantize2, (gpt_neox_layers_5_attention_dense_q_weight3, gpt_neox_layers_5_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm77 = R.call_tir(cls.layer_norm1, (lv84, gpt_neox_layers_5_post_attention_layernorm_weight3, gpt_neox_layers_5_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv192 = R.call_tir(cls.dequantize3, (gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv86 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm77, lv192, gpt_neox_layers_5_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv193 = R.call_tir(cls.dequantize4, (gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv87 = 
R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv86, lv193, gpt_neox_layers_5_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv88 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape151, lv191, gpt_neox_layers_5_attention_dense_bias3, lv87, lv84), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm78 = R.call_tir(cls.layer_norm1, (lv88, gpt_neox_layers_6_input_layernorm_weight3, gpt_neox_layers_6_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv194 = R.call_tir(cls.dequantize1, (gpt_neox_layers_6_attention_query_key_value_q_weight3, gpt_neox_layers_6_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv89 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm78, lv194, gpt_neox_layers_6_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape152 = R.call_tir(cls.reshape4, (lv89,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape153 = R.call_tir(cls.reshape5, (reshape152,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv195 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape153), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape154 = R.call_tir(cls.reshape6, (lv195,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape155 = R.call_tir(cls.reshape7, (reshape154,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv196 = R.call_tir(cls.dequantize2, (gpt_neox_layers_6_attention_dense_q_weight3, gpt_neox_layers_6_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm79 = R.call_tir(cls.layer_norm1, (lv88, gpt_neox_layers_6_post_attention_layernorm_weight3, gpt_neox_layers_6_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv197 = R.call_tir(cls.dequantize3, 
(gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv90 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm79, lv197, gpt_neox_layers_6_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv198 = R.call_tir(cls.dequantize4, (gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv91 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv90, lv198, gpt_neox_layers_6_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv92 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape155, lv196, gpt_neox_layers_6_attention_dense_bias3, lv91, lv88), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm80 = R.call_tir(cls.layer_norm1, (lv92, gpt_neox_layers_7_input_layernorm_weight3, gpt_neox_layers_7_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv199 = R.call_tir(cls.dequantize1, (gpt_neox_layers_7_attention_query_key_value_q_weight3, gpt_neox_layers_7_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv93 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm80, lv199, gpt_neox_layers_7_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape156 = R.call_tir(cls.reshape4, (lv93,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape157 = R.call_tir(cls.reshape5, (reshape156,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv200 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape157), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape158 = R.call_tir(cls.reshape6, (lv200,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape159 = 
R.call_tir(cls.reshape7, (reshape158,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv201 = R.call_tir(cls.dequantize2, (gpt_neox_layers_7_attention_dense_q_weight3, gpt_neox_layers_7_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm81 = R.call_tir(cls.layer_norm1, (lv92, gpt_neox_layers_7_post_attention_layernorm_weight3, gpt_neox_layers_7_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv202 = R.call_tir(cls.dequantize3, (gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv94 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm81, lv202, gpt_neox_layers_7_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv203 = R.call_tir(cls.dequantize4, (gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv95 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv94, lv203, gpt_neox_layers_7_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv96 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape159, lv201, gpt_neox_layers_7_attention_dense_bias3, lv95, lv92), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm82 = R.call_tir(cls.layer_norm1, (lv96, gpt_neox_layers_8_input_layernorm_weight3, gpt_neox_layers_8_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv204 = R.call_tir(cls.dequantize1, (gpt_neox_layers_8_attention_query_key_value_q_weight3, gpt_neox_layers_8_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv97 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm82, lv204, gpt_neox_layers_8_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape160 = R.call_tir(cls.reshape4, (lv97,), 
out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape161 = R.call_tir(cls.reshape5, (reshape160,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv205 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape161), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape162 = R.call_tir(cls.reshape6, (lv205,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape163 = R.call_tir(cls.reshape7, (reshape162,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv206 = R.call_tir(cls.dequantize2, (gpt_neox_layers_8_attention_dense_q_weight3, gpt_neox_layers_8_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm83 = R.call_tir(cls.layer_norm1, (lv96, gpt_neox_layers_8_post_attention_layernorm_weight3, gpt_neox_layers_8_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv207 = R.call_tir(cls.dequantize3, (gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv98 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm83, lv207, gpt_neox_layers_8_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv208 = R.call_tir(cls.dequantize4, (gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv99 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv98, lv208, gpt_neox_layers_8_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv100 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape163, lv206, gpt_neox_layers_8_attention_dense_bias3, lv99, lv96), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm84 = R.call_tir(cls.layer_norm1, (lv100, gpt_neox_layers_9_input_layernorm_weight3, 
gpt_neox_layers_9_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv209 = R.call_tir(cls.dequantize1, (gpt_neox_layers_9_attention_query_key_value_q_weight3, gpt_neox_layers_9_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv101 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm84, lv209, gpt_neox_layers_9_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape164 = R.call_tir(cls.reshape4, (lv101,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape165 = R.call_tir(cls.reshape5, (reshape164,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv210 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape165), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape166 = R.call_tir(cls.reshape6, (lv210,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape167 = R.call_tir(cls.reshape7, (reshape166,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv211 = R.call_tir(cls.dequantize2, (gpt_neox_layers_9_attention_dense_q_weight3, gpt_neox_layers_9_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm85 = R.call_tir(cls.layer_norm1, (lv100, gpt_neox_layers_9_post_attention_layernorm_weight3, gpt_neox_layers_9_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv212 = R.call_tir(cls.dequantize3, (gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv102 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm85, lv212, gpt_neox_layers_9_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv213 = R.call_tir(cls.dequantize4, (gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight3, 
gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv103 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv102, lv213, gpt_neox_layers_9_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv104 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape167, lv211, gpt_neox_layers_9_attention_dense_bias3, lv103, lv100), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm86 = R.call_tir(cls.layer_norm1, (lv104, gpt_neox_layers_10_input_layernorm_weight3, gpt_neox_layers_10_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv214 = R.call_tir(cls.dequantize1, (gpt_neox_layers_10_attention_query_key_value_q_weight3, gpt_neox_layers_10_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv105 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm86, lv214, gpt_neox_layers_10_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape168 = R.call_tir(cls.reshape4, (lv105,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape169 = R.call_tir(cls.reshape5, (reshape168,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv215 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape169), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape170 = R.call_tir(cls.reshape6, (lv215,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape171 = R.call_tir(cls.reshape7, (reshape170,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv216 = R.call_tir(cls.dequantize2, (gpt_neox_layers_10_attention_dense_q_weight3, gpt_neox_layers_10_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm87 = R.call_tir(cls.layer_norm1, (lv104, gpt_neox_layers_10_post_attention_layernorm_weight3, 
gpt_neox_layers_10_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv217 = R.call_tir(cls.dequantize3, (gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv106 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm87, lv217, gpt_neox_layers_10_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv218 = R.call_tir(cls.dequantize4, (gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv107 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv106, lv218, gpt_neox_layers_10_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv108 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape171, lv216, gpt_neox_layers_10_attention_dense_bias3, lv107, lv104), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm88 = R.call_tir(cls.layer_norm1, (lv108, gpt_neox_layers_11_input_layernorm_weight3, gpt_neox_layers_11_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv219 = R.call_tir(cls.dequantize1, (gpt_neox_layers_11_attention_query_key_value_q_weight3, gpt_neox_layers_11_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv109 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm88, lv219, gpt_neox_layers_11_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape172 = R.call_tir(cls.reshape4, (lv109,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape173 = R.call_tir(cls.reshape5, (reshape172,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv220 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape173), out_sinfo=R.Tensor((seq_len, 
8, 256), dtype="float16")) reshape174 = R.call_tir(cls.reshape6, (lv220,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape175 = R.call_tir(cls.reshape7, (reshape174,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv221 = R.call_tir(cls.dequantize2, (gpt_neox_layers_11_attention_dense_q_weight3, gpt_neox_layers_11_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm89 = R.call_tir(cls.layer_norm1, (lv108, gpt_neox_layers_11_post_attention_layernorm_weight3, gpt_neox_layers_11_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv222 = R.call_tir(cls.dequantize3, (gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv110 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm89, lv222, gpt_neox_layers_11_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv223 = R.call_tir(cls.dequantize4, (gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv111 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv110, lv223, gpt_neox_layers_11_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv112 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape175, lv221, gpt_neox_layers_11_attention_dense_bias3, lv111, lv108), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm90 = R.call_tir(cls.layer_norm1, (lv112, gpt_neox_layers_12_input_layernorm_weight3, gpt_neox_layers_12_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv224 = R.call_tir(cls.dequantize1, (gpt_neox_layers_12_attention_query_key_value_q_weight3, gpt_neox_layers_12_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv113 = R.call_tir(cls.fused_NT_matmul5_add5, 
(layer_norm90, lv224, gpt_neox_layers_12_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape176 = R.call_tir(cls.reshape4, (lv113,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape177 = R.call_tir(cls.reshape5, (reshape176,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv225 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape177), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape178 = R.call_tir(cls.reshape6, (lv225,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape179 = R.call_tir(cls.reshape7, (reshape178,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv226 = R.call_tir(cls.dequantize2, (gpt_neox_layers_12_attention_dense_q_weight3, gpt_neox_layers_12_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm91 = R.call_tir(cls.layer_norm1, (lv112, gpt_neox_layers_12_post_attention_layernorm_weight3, gpt_neox_layers_12_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv227 = R.call_tir(cls.dequantize3, (gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv114 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm91, lv227, gpt_neox_layers_12_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv228 = R.call_tir(cls.dequantize4, (gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv115 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv114, lv228, gpt_neox_layers_12_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv116 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape179, lv226, 
gpt_neox_layers_12_attention_dense_bias3, lv115, lv112), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm92 = R.call_tir(cls.layer_norm1, (lv116, gpt_neox_layers_13_input_layernorm_weight3, gpt_neox_layers_13_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv229 = R.call_tir(cls.dequantize1, (gpt_neox_layers_13_attention_query_key_value_q_weight3, gpt_neox_layers_13_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv117 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm92, lv229, gpt_neox_layers_13_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape180 = R.call_tir(cls.reshape4, (lv117,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape181 = R.call_tir(cls.reshape5, (reshape180,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv230 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape181), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape182 = R.call_tir(cls.reshape6, (lv230,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape183 = R.call_tir(cls.reshape7, (reshape182,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv231 = R.call_tir(cls.dequantize2, (gpt_neox_layers_13_attention_dense_q_weight3, gpt_neox_layers_13_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm93 = R.call_tir(cls.layer_norm1, (lv116, gpt_neox_layers_13_post_attention_layernorm_weight3, gpt_neox_layers_13_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv232 = R.call_tir(cls.dequantize3, (gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv118 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm93, lv232, 
gpt_neox_layers_13_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv233 = R.call_tir(cls.dequantize4, (gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv119 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv118, lv233, gpt_neox_layers_13_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv120 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape183, lv231, gpt_neox_layers_13_attention_dense_bias3, lv119, lv116), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm94 = R.call_tir(cls.layer_norm1, (lv120, gpt_neox_layers_14_input_layernorm_weight3, gpt_neox_layers_14_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv234 = R.call_tir(cls.dequantize1, (gpt_neox_layers_14_attention_query_key_value_q_weight3, gpt_neox_layers_14_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv121 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm94, lv234, gpt_neox_layers_14_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape184 = R.call_tir(cls.reshape4, (lv121,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape185 = R.call_tir(cls.reshape5, (reshape184,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv235 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape185), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape186 = R.call_tir(cls.reshape6, (lv235,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape187 = R.call_tir(cls.reshape7, (reshape186,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv236 = R.call_tir(cls.dequantize2, (gpt_neox_layers_14_attention_dense_q_weight3, gpt_neox_layers_14_attention_dense_q_scale3), 
out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm95 = R.call_tir(cls.layer_norm1, (lv120, gpt_neox_layers_14_post_attention_layernorm_weight3, gpt_neox_layers_14_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv237 = R.call_tir(cls.dequantize3, (gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv122 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm95, lv237, gpt_neox_layers_14_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv238 = R.call_tir(cls.dequantize4, (gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv123 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv122, lv238, gpt_neox_layers_14_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv124 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape187, lv236, gpt_neox_layers_14_attention_dense_bias3, lv123, lv120), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm96 = R.call_tir(cls.layer_norm1, (lv124, gpt_neox_layers_15_input_layernorm_weight3, gpt_neox_layers_15_input_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv239 = R.call_tir(cls.dequantize1, (gpt_neox_layers_15_attention_query_key_value_q_weight3, gpt_neox_layers_15_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv125 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm96, lv239, gpt_neox_layers_15_attention_query_key_value_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape188 = R.call_tir(cls.reshape4, (lv125,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape189 = R.call_tir(cls.reshape5, (reshape188,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv240 = 
R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape189), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape190 = R.call_tir(cls.reshape6, (lv240,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape191 = R.call_tir(cls.reshape7, (reshape190,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv241 = R.call_tir(cls.dequantize2, (gpt_neox_layers_15_attention_dense_q_weight3, gpt_neox_layers_15_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm97 = R.call_tir(cls.layer_norm1, (lv124, gpt_neox_layers_15_post_attention_layernorm_weight3, gpt_neox_layers_15_post_attention_layernorm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv242 = R.call_tir(cls.dequantize3, (gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv126 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm97, lv242, gpt_neox_layers_15_mlp_dense_h_to_4h_bias3), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv243 = R.call_tir(cls.dequantize4, (gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv127 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv126, lv243, gpt_neox_layers_15_mlp_dense_4h_to_h_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv128 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape191, lv241, gpt_neox_layers_15_attention_dense_bias3, lv127, lv124), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm98 = R.call_tir(cls.layer_norm1, (lv128, gpt_neox_final_layer_norm_weight3, gpt_neox_final_layer_norm_bias3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) take1 = R.call_tir(cls.take, (layer_norm98, logit_positions), out_sinfo=R.Tensor((1, 
batch_size, 2048), dtype="float16")) lv244 = R.call_tir(cls.dequantize, (embed_out_q_weight3, embed_out_q_scale3), out_sinfo=R.Tensor((vocab_size, 2048), dtype="float16")) lv129 = R.call_tir(cls.fused_NT_matmul9_cast5, (take1, lv244), out_sinfo=R.Tensor((1, batch_size, vocab_size), dtype="float32")) gv3: R.Tuple(R.Tensor((1, batch_size, vocab_size), dtype="float32"), R.Object) = lv129, paged_kv_cache R.output(gv3) return gv3 @R.function def batch_verify(input_embeds: R.Tensor((1, "seq_len", 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), 
dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), 
dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), 
dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"))) -> R.Tuple(R.Tensor((1, "seq_len", "vocab_size"), dtype="float32"), R.Object): seq_len = T.int64() vocab_size = T.int64() R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size", "seq_len"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module with R.dataflow(): gpt_neox_layers_0_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[2] gpt_neox_layers_0_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[3] gpt_neox_layers_0_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[4] gpt_neox_layers_0_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[5] gpt_neox_layers_0_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[6] gpt_neox_layers_0_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[7] gpt_neox_layers_0_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[8] gpt_neox_layers_0_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[9] gpt_neox_layers_0_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[10] gpt_neox_layers_0_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[11] 
gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[12] gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[13] gpt_neox_layers_0_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[14] gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[15] gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[16] gpt_neox_layers_0_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[17] gpt_neox_layers_1_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[18] gpt_neox_layers_1_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[19] gpt_neox_layers_1_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[20] gpt_neox_layers_1_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[21] gpt_neox_layers_1_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[22] gpt_neox_layers_1_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[23] gpt_neox_layers_1_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[24] gpt_neox_layers_1_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[25] gpt_neox_layers_1_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[26] gpt_neox_layers_1_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[27] gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[28] gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[29] gpt_neox_layers_1_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[30] 
gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[31] gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[32] gpt_neox_layers_1_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[33] gpt_neox_layers_2_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[34] gpt_neox_layers_2_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[35] gpt_neox_layers_2_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[36] gpt_neox_layers_2_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[37] gpt_neox_layers_2_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[38] gpt_neox_layers_2_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[39] gpt_neox_layers_2_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[40] gpt_neox_layers_2_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[41] gpt_neox_layers_2_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[42] gpt_neox_layers_2_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[43] gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[44] gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[45] gpt_neox_layers_2_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[46] gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[47] gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[48] gpt_neox_layers_2_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[49] 
gpt_neox_layers_3_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[50] gpt_neox_layers_3_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[51] gpt_neox_layers_3_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[52] gpt_neox_layers_3_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[53] gpt_neox_layers_3_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[54] gpt_neox_layers_3_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[55] gpt_neox_layers_3_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[56] gpt_neox_layers_3_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[57] gpt_neox_layers_3_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[58] gpt_neox_layers_3_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[59] gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[60] gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[61] gpt_neox_layers_3_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[62] gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[63] gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[64] gpt_neox_layers_3_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[65] gpt_neox_layers_4_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[66] gpt_neox_layers_4_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[67] gpt_neox_layers_4_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[68] 
gpt_neox_layers_4_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[69] gpt_neox_layers_4_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[70] gpt_neox_layers_4_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[71] gpt_neox_layers_4_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[72] gpt_neox_layers_4_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[73] gpt_neox_layers_4_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[74] gpt_neox_layers_4_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[75] gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[76] gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[77] gpt_neox_layers_4_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[78] gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[79] gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[80] gpt_neox_layers_4_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[81] gpt_neox_layers_5_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[82] gpt_neox_layers_5_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[83] gpt_neox_layers_5_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[84] gpt_neox_layers_5_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[85] gpt_neox_layers_5_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[86] gpt_neox_layers_5_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[87] 
gpt_neox_layers_5_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[88] gpt_neox_layers_5_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[89] gpt_neox_layers_5_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[90] gpt_neox_layers_5_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[91] gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[92] gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[93] gpt_neox_layers_5_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[94] gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[95] gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[96] gpt_neox_layers_5_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[97] gpt_neox_layers_6_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[98] gpt_neox_layers_6_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[99] gpt_neox_layers_6_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[100] gpt_neox_layers_6_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[101] gpt_neox_layers_6_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[102] gpt_neox_layers_6_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[103] gpt_neox_layers_6_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[104] gpt_neox_layers_6_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[105] gpt_neox_layers_6_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[106] 
gpt_neox_layers_6_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[107] gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[108] gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[109] gpt_neox_layers_6_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[110] gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[111] gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[112] gpt_neox_layers_6_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[113] gpt_neox_layers_7_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[114] gpt_neox_layers_7_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[115] gpt_neox_layers_7_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[116] gpt_neox_layers_7_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[117] gpt_neox_layers_7_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[118] gpt_neox_layers_7_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[119] gpt_neox_layers_7_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[120] gpt_neox_layers_7_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[121] gpt_neox_layers_7_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[122] gpt_neox_layers_7_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[123] gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[124] gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[125] 
gpt_neox_layers_7_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[126] gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[127] gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[128] gpt_neox_layers_7_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[129] gpt_neox_layers_8_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[130] gpt_neox_layers_8_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[131] gpt_neox_layers_8_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[132] gpt_neox_layers_8_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[133] gpt_neox_layers_8_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[134] gpt_neox_layers_8_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[135] gpt_neox_layers_8_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[136] gpt_neox_layers_8_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[137] gpt_neox_layers_8_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[138] gpt_neox_layers_8_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[139] gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[140] gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[141] gpt_neox_layers_8_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[142] gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[143] gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[144] 
gpt_neox_layers_8_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[145] gpt_neox_layers_9_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[146] gpt_neox_layers_9_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[147] gpt_neox_layers_9_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[148] gpt_neox_layers_9_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[149] gpt_neox_layers_9_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[150] gpt_neox_layers_9_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[151] gpt_neox_layers_9_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[152] gpt_neox_layers_9_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[153] gpt_neox_layers_9_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[154] gpt_neox_layers_9_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[155] gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[156] gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[157] gpt_neox_layers_9_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[158] gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[159] gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[160] gpt_neox_layers_9_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[161] gpt_neox_layers_10_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[162] gpt_neox_layers_10_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[163] 
gpt_neox_layers_10_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[164] gpt_neox_layers_10_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[165] gpt_neox_layers_10_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[166] gpt_neox_layers_10_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[167] gpt_neox_layers_10_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[168] gpt_neox_layers_10_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[169] gpt_neox_layers_10_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[170] gpt_neox_layers_10_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[171] gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[172] gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[173] gpt_neox_layers_10_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[174] gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[175] gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[176] gpt_neox_layers_10_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[177] gpt_neox_layers_11_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[178] gpt_neox_layers_11_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[179] gpt_neox_layers_11_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[180] gpt_neox_layers_11_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[181] gpt_neox_layers_11_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = 
packed_params[182] gpt_neox_layers_11_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[183] gpt_neox_layers_11_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[184] gpt_neox_layers_11_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[185] gpt_neox_layers_11_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[186] gpt_neox_layers_11_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[187] gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[188] gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[189] gpt_neox_layers_11_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[190] gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[191] gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[192] gpt_neox_layers_11_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[193] gpt_neox_layers_12_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[194] gpt_neox_layers_12_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[195] gpt_neox_layers_12_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[196] gpt_neox_layers_12_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[197] gpt_neox_layers_12_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[198] gpt_neox_layers_12_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[199] gpt_neox_layers_12_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[200] gpt_neox_layers_12_attention_dense_q_weight5: R.Tensor((2048, 256), 
dtype="uint32") = packed_params[201] gpt_neox_layers_12_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[202] gpt_neox_layers_12_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[203] gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[204] gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[205] gpt_neox_layers_12_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[206] gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[207] gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[208] gpt_neox_layers_12_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[209] gpt_neox_layers_13_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[210] gpt_neox_layers_13_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[211] gpt_neox_layers_13_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[212] gpt_neox_layers_13_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[213] gpt_neox_layers_13_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[214] gpt_neox_layers_13_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[215] gpt_neox_layers_13_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[216] gpt_neox_layers_13_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[217] gpt_neox_layers_13_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[218] gpt_neox_layers_13_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[219] gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), 
dtype="uint32") = packed_params[220] gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[221] gpt_neox_layers_13_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[222] gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[223] gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[224] gpt_neox_layers_13_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[225] gpt_neox_layers_14_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[226] gpt_neox_layers_14_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[227] gpt_neox_layers_14_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[228] gpt_neox_layers_14_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[229] gpt_neox_layers_14_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[230] gpt_neox_layers_14_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[231] gpt_neox_layers_14_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[232] gpt_neox_layers_14_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[233] gpt_neox_layers_14_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[234] gpt_neox_layers_14_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[235] gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[236] gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[237] gpt_neox_layers_14_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[238] gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), 
dtype="uint32") = packed_params[239] gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[240] gpt_neox_layers_14_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[241] gpt_neox_layers_15_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[242] gpt_neox_layers_15_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[243] gpt_neox_layers_15_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[244] gpt_neox_layers_15_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[245] gpt_neox_layers_15_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[246] gpt_neox_layers_15_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[247] gpt_neox_layers_15_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[248] gpt_neox_layers_15_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[249] gpt_neox_layers_15_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[250] gpt_neox_layers_15_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[251] gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[252] gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[253] gpt_neox_layers_15_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[254] gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[255] gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[256] gpt_neox_layers_15_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[257] gpt_neox_final_layer_norm_weight5: R.Tensor((2048,), dtype="float16") 
= packed_params[258] gpt_neox_final_layer_norm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[259] embed_out_q_weight5: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[260] embed_out_q_scale5: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[261] layer_norm132 = R.call_tir(cls.layer_norm1, (input_embeds, gpt_neox_layers_0_input_layernorm_weight5, gpt_neox_layers_0_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv326 = R.call_tir(cls.dequantize1, (gpt_neox_layers_0_attention_query_key_value_q_weight5, gpt_neox_layers_0_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv130 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm132, lv326, gpt_neox_layers_0_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape256 = R.call_tir(cls.reshape4, (lv130,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape257 = R.call_tir(cls.reshape5, (reshape256,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv327 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape257), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape258 = R.call_tir(cls.reshape6, (lv327,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape259 = R.call_tir(cls.reshape7, (reshape258,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv328 = R.call_tir(cls.dequantize2, (gpt_neox_layers_0_attention_dense_q_weight5, gpt_neox_layers_0_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm133 = R.call_tir(cls.layer_norm1, (input_embeds, gpt_neox_layers_0_post_attention_layernorm_weight5, gpt_neox_layers_0_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv329 = R.call_tir(cls.dequantize3, (gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight5, 
gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv131 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm133, lv329, gpt_neox_layers_0_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv330 = R.call_tir(cls.dequantize4, (gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv132 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv131, lv330, gpt_neox_layers_0_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv133 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape259, lv328, gpt_neox_layers_0_attention_dense_bias5, lv132, input_embeds), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm134 = R.call_tir(cls.layer_norm1, (lv133, gpt_neox_layers_1_input_layernorm_weight5, gpt_neox_layers_1_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv331 = R.call_tir(cls.dequantize1, (gpt_neox_layers_1_attention_query_key_value_q_weight5, gpt_neox_layers_1_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv134 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm134, lv331, gpt_neox_layers_1_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape260 = R.call_tir(cls.reshape4, (lv134,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape261 = R.call_tir(cls.reshape5, (reshape260,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv332 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape261), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape262 = R.call_tir(cls.reshape6, (lv332,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape263 = R.call_tir(cls.reshape7, (reshape262,), 
out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv333 = R.call_tir(cls.dequantize2, (gpt_neox_layers_1_attention_dense_q_weight5, gpt_neox_layers_1_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm135 = R.call_tir(cls.layer_norm1, (lv133, gpt_neox_layers_1_post_attention_layernorm_weight5, gpt_neox_layers_1_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv334 = R.call_tir(cls.dequantize3, (gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv135 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm135, lv334, gpt_neox_layers_1_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv335 = R.call_tir(cls.dequantize4, (gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv136 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv135, lv335, gpt_neox_layers_1_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv137 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape263, lv333, gpt_neox_layers_1_attention_dense_bias5, lv136, lv133), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm136 = R.call_tir(cls.layer_norm1, (lv137, gpt_neox_layers_2_input_layernorm_weight5, gpt_neox_layers_2_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv336 = R.call_tir(cls.dequantize1, (gpt_neox_layers_2_attention_query_key_value_q_weight5, gpt_neox_layers_2_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv138 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm136, lv336, gpt_neox_layers_2_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape264 = R.call_tir(cls.reshape4, (lv138,), out_sinfo=R.Tensor((1, seq_len, 
24, 256), dtype="float16")) reshape265 = R.call_tir(cls.reshape5, (reshape264,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv337 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape265), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape266 = R.call_tir(cls.reshape6, (lv337,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape267 = R.call_tir(cls.reshape7, (reshape266,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv338 = R.call_tir(cls.dequantize2, (gpt_neox_layers_2_attention_dense_q_weight5, gpt_neox_layers_2_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm137 = R.call_tir(cls.layer_norm1, (lv137, gpt_neox_layers_2_post_attention_layernorm_weight5, gpt_neox_layers_2_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv339 = R.call_tir(cls.dequantize3, (gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv139 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm137, lv339, gpt_neox_layers_2_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv340 = R.call_tir(cls.dequantize4, (gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv140 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv139, lv340, gpt_neox_layers_2_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv141 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape267, lv338, gpt_neox_layers_2_attention_dense_bias5, lv140, lv137), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm138 = R.call_tir(cls.layer_norm1, (lv141, gpt_neox_layers_3_input_layernorm_weight5, 
gpt_neox_layers_3_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv341 = R.call_tir(cls.dequantize1, (gpt_neox_layers_3_attention_query_key_value_q_weight5, gpt_neox_layers_3_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv142 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm138, lv341, gpt_neox_layers_3_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape268 = R.call_tir(cls.reshape4, (lv142,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape269 = R.call_tir(cls.reshape5, (reshape268,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv342 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape269), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape270 = R.call_tir(cls.reshape6, (lv342,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape271 = R.call_tir(cls.reshape7, (reshape270,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv343 = R.call_tir(cls.dequantize2, (gpt_neox_layers_3_attention_dense_q_weight5, gpt_neox_layers_3_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm139 = R.call_tir(cls.layer_norm1, (lv141, gpt_neox_layers_3_post_attention_layernorm_weight5, gpt_neox_layers_3_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv344 = R.call_tir(cls.dequantize3, (gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv143 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm139, lv344, gpt_neox_layers_3_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv345 = R.call_tir(cls.dequantize4, (gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight5, 
gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv144 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv143, lv345, gpt_neox_layers_3_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv145 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape271, lv343, gpt_neox_layers_3_attention_dense_bias5, lv144, lv141), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm140 = R.call_tir(cls.layer_norm1, (lv145, gpt_neox_layers_4_input_layernorm_weight5, gpt_neox_layers_4_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv346 = R.call_tir(cls.dequantize1, (gpt_neox_layers_4_attention_query_key_value_q_weight5, gpt_neox_layers_4_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv146 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm140, lv346, gpt_neox_layers_4_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape272 = R.call_tir(cls.reshape4, (lv146,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape273 = R.call_tir(cls.reshape5, (reshape272,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv347 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape273), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape274 = R.call_tir(cls.reshape6, (lv347,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape275 = R.call_tir(cls.reshape7, (reshape274,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv348 = R.call_tir(cls.dequantize2, (gpt_neox_layers_4_attention_dense_q_weight5, gpt_neox_layers_4_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm141 = R.call_tir(cls.layer_norm1, (lv145, gpt_neox_layers_4_post_attention_layernorm_weight5, 
gpt_neox_layers_4_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv349 = R.call_tir(cls.dequantize3, (gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv147 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm141, lv349, gpt_neox_layers_4_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv350 = R.call_tir(cls.dequantize4, (gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv148 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv147, lv350, gpt_neox_layers_4_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv149 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape275, lv348, gpt_neox_layers_4_attention_dense_bias5, lv148, lv145), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm142 = R.call_tir(cls.layer_norm1, (lv149, gpt_neox_layers_5_input_layernorm_weight5, gpt_neox_layers_5_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv351 = R.call_tir(cls.dequantize1, (gpt_neox_layers_5_attention_query_key_value_q_weight5, gpt_neox_layers_5_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv150 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm142, lv351, gpt_neox_layers_5_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape276 = R.call_tir(cls.reshape4, (lv150,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape277 = R.call_tir(cls.reshape5, (reshape276,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv352 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape277), out_sinfo=R.Tensor((seq_len, 8, 256), 
dtype="float16")) reshape278 = R.call_tir(cls.reshape6, (lv352,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape279 = R.call_tir(cls.reshape7, (reshape278,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv353 = R.call_tir(cls.dequantize2, (gpt_neox_layers_5_attention_dense_q_weight5, gpt_neox_layers_5_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm143 = R.call_tir(cls.layer_norm1, (lv149, gpt_neox_layers_5_post_attention_layernorm_weight5, gpt_neox_layers_5_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv354 = R.call_tir(cls.dequantize3, (gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv151 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm143, lv354, gpt_neox_layers_5_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv355 = R.call_tir(cls.dequantize4, (gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv152 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv151, lv355, gpt_neox_layers_5_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv153 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape279, lv353, gpt_neox_layers_5_attention_dense_bias5, lv152, lv149), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm144 = R.call_tir(cls.layer_norm1, (lv153, gpt_neox_layers_6_input_layernorm_weight5, gpt_neox_layers_6_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv356 = R.call_tir(cls.dequantize1, (gpt_neox_layers_6_attention_query_key_value_q_weight5, gpt_neox_layers_6_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv154 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm144, lv356, 
gpt_neox_layers_6_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape280 = R.call_tir(cls.reshape4, (lv154,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape281 = R.call_tir(cls.reshape5, (reshape280,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv357 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape281), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape282 = R.call_tir(cls.reshape6, (lv357,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape283 = R.call_tir(cls.reshape7, (reshape282,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv358 = R.call_tir(cls.dequantize2, (gpt_neox_layers_6_attention_dense_q_weight5, gpt_neox_layers_6_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm145 = R.call_tir(cls.layer_norm1, (lv153, gpt_neox_layers_6_post_attention_layernorm_weight5, gpt_neox_layers_6_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv359 = R.call_tir(cls.dequantize3, (gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv155 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm145, lv359, gpt_neox_layers_6_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv360 = R.call_tir(cls.dequantize4, (gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv156 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv155, lv360, gpt_neox_layers_6_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv157 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape283, lv358, gpt_neox_layers_6_attention_dense_bias5, lv156, lv153), 
out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm146 = R.call_tir(cls.layer_norm1, (lv157, gpt_neox_layers_7_input_layernorm_weight5, gpt_neox_layers_7_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv361 = R.call_tir(cls.dequantize1, (gpt_neox_layers_7_attention_query_key_value_q_weight5, gpt_neox_layers_7_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv158 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm146, lv361, gpt_neox_layers_7_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape284 = R.call_tir(cls.reshape4, (lv158,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape285 = R.call_tir(cls.reshape5, (reshape284,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv362 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape285), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape286 = R.call_tir(cls.reshape6, (lv362,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape287 = R.call_tir(cls.reshape7, (reshape286,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv363 = R.call_tir(cls.dequantize2, (gpt_neox_layers_7_attention_dense_q_weight5, gpt_neox_layers_7_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm147 = R.call_tir(cls.layer_norm1, (lv157, gpt_neox_layers_7_post_attention_layernorm_weight5, gpt_neox_layers_7_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv364 = R.call_tir(cls.dequantize3, (gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv159 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm147, lv364, gpt_neox_layers_7_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, 
seq_len, 8192), dtype="float16")) lv365 = R.call_tir(cls.dequantize4, (gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv160 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv159, lv365, gpt_neox_layers_7_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv161 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape287, lv363, gpt_neox_layers_7_attention_dense_bias5, lv160, lv157), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm148 = R.call_tir(cls.layer_norm1, (lv161, gpt_neox_layers_8_input_layernorm_weight5, gpt_neox_layers_8_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv366 = R.call_tir(cls.dequantize1, (gpt_neox_layers_8_attention_query_key_value_q_weight5, gpt_neox_layers_8_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv162 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm148, lv366, gpt_neox_layers_8_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape288 = R.call_tir(cls.reshape4, (lv162,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape289 = R.call_tir(cls.reshape5, (reshape288,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv367 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape289), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape290 = R.call_tir(cls.reshape6, (lv367,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape291 = R.call_tir(cls.reshape7, (reshape290,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv368 = R.call_tir(cls.dequantize2, (gpt_neox_layers_8_attention_dense_q_weight5, gpt_neox_layers_8_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm149 = 
R.call_tir(cls.layer_norm1, (lv161, gpt_neox_layers_8_post_attention_layernorm_weight5, gpt_neox_layers_8_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv369 = R.call_tir(cls.dequantize3, (gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv163 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm149, lv369, gpt_neox_layers_8_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv370 = R.call_tir(cls.dequantize4, (gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv164 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv163, lv370, gpt_neox_layers_8_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv165 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape291, lv368, gpt_neox_layers_8_attention_dense_bias5, lv164, lv161), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm150 = R.call_tir(cls.layer_norm1, (lv165, gpt_neox_layers_9_input_layernorm_weight5, gpt_neox_layers_9_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv371 = R.call_tir(cls.dequantize1, (gpt_neox_layers_9_attention_query_key_value_q_weight5, gpt_neox_layers_9_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv166 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm150, lv371, gpt_neox_layers_9_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape292 = R.call_tir(cls.reshape4, (lv166,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape293 = R.call_tir(cls.reshape5, (reshape292,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv372 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, 
R.prim_value(9), R.prim_value(T.float32(1.0)), reshape293), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape294 = R.call_tir(cls.reshape6, (lv372,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape295 = R.call_tir(cls.reshape7, (reshape294,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv373 = R.call_tir(cls.dequantize2, (gpt_neox_layers_9_attention_dense_q_weight5, gpt_neox_layers_9_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm151 = R.call_tir(cls.layer_norm1, (lv165, gpt_neox_layers_9_post_attention_layernorm_weight5, gpt_neox_layers_9_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv374 = R.call_tir(cls.dequantize3, (gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv167 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm151, lv374, gpt_neox_layers_9_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv375 = R.call_tir(cls.dequantize4, (gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv168 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv167, lv375, gpt_neox_layers_9_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv169 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape295, lv373, gpt_neox_layers_9_attention_dense_bias5, lv168, lv165), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm152 = R.call_tir(cls.layer_norm1, (lv169, gpt_neox_layers_10_input_layernorm_weight5, gpt_neox_layers_10_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv376 = R.call_tir(cls.dequantize1, (gpt_neox_layers_10_attention_query_key_value_q_weight5, gpt_neox_layers_10_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 
2048), dtype="float16")) lv170 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm152, lv376, gpt_neox_layers_10_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape296 = R.call_tir(cls.reshape4, (lv170,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape297 = R.call_tir(cls.reshape5, (reshape296,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv377 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape297), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape298 = R.call_tir(cls.reshape6, (lv377,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape299 = R.call_tir(cls.reshape7, (reshape298,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv378 = R.call_tir(cls.dequantize2, (gpt_neox_layers_10_attention_dense_q_weight5, gpt_neox_layers_10_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm153 = R.call_tir(cls.layer_norm1, (lv169, gpt_neox_layers_10_post_attention_layernorm_weight5, gpt_neox_layers_10_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv379 = R.call_tir(cls.dequantize3, (gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv171 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm153, lv379, gpt_neox_layers_10_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv380 = R.call_tir(cls.dequantize4, (gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv172 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv171, lv380, gpt_neox_layers_10_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv173 = 
R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape299, lv378, gpt_neox_layers_10_attention_dense_bias5, lv172, lv169), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm154 = R.call_tir(cls.layer_norm1, (lv173, gpt_neox_layers_11_input_layernorm_weight5, gpt_neox_layers_11_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv381 = R.call_tir(cls.dequantize1, (gpt_neox_layers_11_attention_query_key_value_q_weight5, gpt_neox_layers_11_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv174 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm154, lv381, gpt_neox_layers_11_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape300 = R.call_tir(cls.reshape4, (lv174,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape301 = R.call_tir(cls.reshape5, (reshape300,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv382 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape301), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape302 = R.call_tir(cls.reshape6, (lv382,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape303 = R.call_tir(cls.reshape7, (reshape302,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv383 = R.call_tir(cls.dequantize2, (gpt_neox_layers_11_attention_dense_q_weight5, gpt_neox_layers_11_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm155 = R.call_tir(cls.layer_norm1, (lv173, gpt_neox_layers_11_post_attention_layernorm_weight5, gpt_neox_layers_11_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv384 = R.call_tir(cls.dequantize3, (gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv175 = 
R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm155, lv384, gpt_neox_layers_11_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv385 = R.call_tir(cls.dequantize4, (gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv176 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv175, lv385, gpt_neox_layers_11_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv177 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape303, lv383, gpt_neox_layers_11_attention_dense_bias5, lv176, lv173), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm156 = R.call_tir(cls.layer_norm1, (lv177, gpt_neox_layers_12_input_layernorm_weight5, gpt_neox_layers_12_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv386 = R.call_tir(cls.dequantize1, (gpt_neox_layers_12_attention_query_key_value_q_weight5, gpt_neox_layers_12_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv178 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm156, lv386, gpt_neox_layers_12_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape304 = R.call_tir(cls.reshape4, (lv178,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape305 = R.call_tir(cls.reshape5, (reshape304,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv387 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape305), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape306 = R.call_tir(cls.reshape6, (lv387,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape307 = R.call_tir(cls.reshape7, (reshape306,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv388 = R.call_tir(cls.dequantize2, 
(gpt_neox_layers_12_attention_dense_q_weight5, gpt_neox_layers_12_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm157 = R.call_tir(cls.layer_norm1, (lv177, gpt_neox_layers_12_post_attention_layernorm_weight5, gpt_neox_layers_12_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv389 = R.call_tir(cls.dequantize3, (gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv179 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm157, lv389, gpt_neox_layers_12_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv390 = R.call_tir(cls.dequantize4, (gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv180 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv179, lv390, gpt_neox_layers_12_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv181 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape307, lv388, gpt_neox_layers_12_attention_dense_bias5, lv180, lv177), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm158 = R.call_tir(cls.layer_norm1, (lv181, gpt_neox_layers_13_input_layernorm_weight5, gpt_neox_layers_13_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv391 = R.call_tir(cls.dequantize1, (gpt_neox_layers_13_attention_query_key_value_q_weight5, gpt_neox_layers_13_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv182 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm158, lv391, gpt_neox_layers_13_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape308 = R.call_tir(cls.reshape4, (lv182,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape309 = R.call_tir(cls.reshape5, 
(reshape308,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv392 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape309), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape310 = R.call_tir(cls.reshape6, (lv392,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape311 = R.call_tir(cls.reshape7, (reshape310,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv393 = R.call_tir(cls.dequantize2, (gpt_neox_layers_13_attention_dense_q_weight5, gpt_neox_layers_13_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm159 = R.call_tir(cls.layer_norm1, (lv181, gpt_neox_layers_13_post_attention_layernorm_weight5, gpt_neox_layers_13_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv394 = R.call_tir(cls.dequantize3, (gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv183 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm159, lv394, gpt_neox_layers_13_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv395 = R.call_tir(cls.dequantize4, (gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv184 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv183, lv395, gpt_neox_layers_13_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv185 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape311, lv393, gpt_neox_layers_13_attention_dense_bias5, lv184, lv181), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm160 = R.call_tir(cls.layer_norm1, (lv185, gpt_neox_layers_14_input_layernorm_weight5, gpt_neox_layers_14_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), 
dtype="float16")) lv396 = R.call_tir(cls.dequantize1, (gpt_neox_layers_14_attention_query_key_value_q_weight5, gpt_neox_layers_14_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv186 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm160, lv396, gpt_neox_layers_14_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape312 = R.call_tir(cls.reshape4, (lv186,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape313 = R.call_tir(cls.reshape5, (reshape312,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv397 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape313), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape314 = R.call_tir(cls.reshape6, (lv397,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape315 = R.call_tir(cls.reshape7, (reshape314,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv398 = R.call_tir(cls.dequantize2, (gpt_neox_layers_14_attention_dense_q_weight5, gpt_neox_layers_14_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm161 = R.call_tir(cls.layer_norm1, (lv185, gpt_neox_layers_14_post_attention_layernorm_weight5, gpt_neox_layers_14_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv399 = R.call_tir(cls.dequantize3, (gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv187 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm161, lv399, gpt_neox_layers_14_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv400 = R.call_tir(cls.dequantize4, (gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) 
lv188 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv187, lv400, gpt_neox_layers_14_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv189 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape315, lv398, gpt_neox_layers_14_attention_dense_bias5, lv188, lv185), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm162 = R.call_tir(cls.layer_norm1, (lv189, gpt_neox_layers_15_input_layernorm_weight5, gpt_neox_layers_15_input_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv401 = R.call_tir(cls.dequantize1, (gpt_neox_layers_15_attention_query_key_value_q_weight5, gpt_neox_layers_15_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv190 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm162, lv401, gpt_neox_layers_15_attention_query_key_value_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape316 = R.call_tir(cls.reshape4, (lv190,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape317 = R.call_tir(cls.reshape5, (reshape316,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv402 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape317), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape318 = R.call_tir(cls.reshape6, (lv402,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape319 = R.call_tir(cls.reshape7, (reshape318,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv403 = R.call_tir(cls.dequantize2, (gpt_neox_layers_15_attention_dense_q_weight5, gpt_neox_layers_15_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm163 = R.call_tir(cls.layer_norm1, (lv189, gpt_neox_layers_15_post_attention_layernorm_weight5, gpt_neox_layers_15_post_attention_layernorm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv404 = 
R.call_tir(cls.dequantize3, (gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv191 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm163, lv404, gpt_neox_layers_15_mlp_dense_h_to_4h_bias5), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv405 = R.call_tir(cls.dequantize4, (gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv192 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv191, lv405, gpt_neox_layers_15_mlp_dense_4h_to_h_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv193 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape319, lv403, gpt_neox_layers_15_attention_dense_bias5, lv192, lv189), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm164 = R.call_tir(cls.layer_norm1, (lv193, gpt_neox_final_layer_norm_weight5, gpt_neox_final_layer_norm_bias5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv406 = R.call_tir(cls.dequantize, (embed_out_q_weight5, embed_out_q_scale5), out_sinfo=R.Tensor((vocab_size, 2048), dtype="float16")) lv194 = R.call_tir(cls.fused_NT_matmul9_cast5, (layer_norm164, lv406), out_sinfo=R.Tensor((1, seq_len, vocab_size), dtype="float32")) gv5: R.Tuple(R.Tensor((1, seq_len, vocab_size), dtype="float32"), R.Object) = lv194, paged_kv_cache R.output(gv5) return gv5 @R.function def create_tir_paged_kv_cache(max_batch_size_: R.Shape(["max_batch_size"]), max_total_seq_len_: R.Shape(["max_total_seq_len"]), prefill_chunk_size_: R.Shape(["prefill_chunk_size"]), page_size_: R.Shape(["page_size"]), support_sliding_window_: R.Shape(["support_sliding_window"])) -> R.Object: max_batch_size = T.int64() max_total_seq_len = T.int64() prefill_chunk_size = T.int64() page_size = T.int64() support_sliding_window = T.int64() R.func_attr({"relax.memory_plan_dynamic_func_output": True, 
"tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module paged_kv_cache: R.Object = R.call_pure_packed("vm.builtin.paged_attention_kv_cache_create_reduced", R.shape([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size, support_sliding_window]), R.shape([0, 16]), R.prim_value(8), R.prim_value(8), R.prim_value(256), R.prim_value(1), R.prim_value(1), R.prim_value(10000), R.const(0.0, "float16"), cls.tir_kv_cache_transpose_append, cls.batch_prefill_paged_kv, cls.batch_decode_paged_kv, cls.batch_prefill_paged_kv_sliding_window, cls.batch_decode_paged_kv_sliding_window, cls.batch_prefill_ragged_kv, cls.merge_state_inplace, cls.fused_rope, cls.copy_single_page, cls.tir_kv_cache_debug_get_kv, cls.compact_kv_copy, cls.batch_tree_attn, cls.tree_attn_paged_kv, R.prim_value(0), R.prim_value(0), sinfo_args=(R.Object,)) return paged_kv_cache @R.function def decode(input_embed: R.Tensor((1, 1, 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), 
dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), 
dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), 
dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), 
dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, "vocab_size"), dtype="float32"), R.Object): vocab_size = T.int64() R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module with R.dataflow(): gpt_neox_layers_0_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[2] gpt_neox_layers_0_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[3] gpt_neox_layers_0_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[4] gpt_neox_layers_0_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[5] gpt_neox_layers_0_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[6] gpt_neox_layers_0_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[7] 
gpt_neox_layers_0_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[8] gpt_neox_layers_0_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[9] gpt_neox_layers_0_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[10] gpt_neox_layers_0_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[11] gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[12] gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[13] gpt_neox_layers_0_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[14] gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[15] gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[16] gpt_neox_layers_0_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[17] gpt_neox_layers_1_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[18] gpt_neox_layers_1_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[19] gpt_neox_layers_1_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[20] gpt_neox_layers_1_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[21] gpt_neox_layers_1_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[22] gpt_neox_layers_1_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[23] gpt_neox_layers_1_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[24] gpt_neox_layers_1_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[25] gpt_neox_layers_1_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[26] 
gpt_neox_layers_1_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[27] gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[28] gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[29] gpt_neox_layers_1_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[30] gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[31] gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[32] gpt_neox_layers_1_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[33] gpt_neox_layers_2_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[34] gpt_neox_layers_2_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[35] gpt_neox_layers_2_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[36] gpt_neox_layers_2_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[37] gpt_neox_layers_2_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[38] gpt_neox_layers_2_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[39] gpt_neox_layers_2_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[40] gpt_neox_layers_2_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[41] gpt_neox_layers_2_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[42] gpt_neox_layers_2_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[43] gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[44] gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[45] 
gpt_neox_layers_2_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[46] gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[47] gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[48] gpt_neox_layers_2_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[49] gpt_neox_layers_3_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[50] gpt_neox_layers_3_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[51] gpt_neox_layers_3_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[52] gpt_neox_layers_3_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[53] gpt_neox_layers_3_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[54] gpt_neox_layers_3_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[55] gpt_neox_layers_3_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[56] gpt_neox_layers_3_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[57] gpt_neox_layers_3_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[58] gpt_neox_layers_3_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[59] gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[60] gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[61] gpt_neox_layers_3_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[62] gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[63] gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[64] 
gpt_neox_layers_3_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[65] gpt_neox_layers_4_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[66] gpt_neox_layers_4_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[67] gpt_neox_layers_4_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[68] gpt_neox_layers_4_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[69] gpt_neox_layers_4_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[70] gpt_neox_layers_4_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[71] gpt_neox_layers_4_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[72] gpt_neox_layers_4_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[73] gpt_neox_layers_4_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[74] gpt_neox_layers_4_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[75] gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[76] gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[77] gpt_neox_layers_4_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[78] gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[79] gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[80] gpt_neox_layers_4_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[81] gpt_neox_layers_5_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[82] gpt_neox_layers_5_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[83] gpt_neox_layers_5_post_attention_layernorm_weight2: 
R.Tensor((2048,), dtype="float16") = packed_params[84] gpt_neox_layers_5_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[85] gpt_neox_layers_5_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[86] gpt_neox_layers_5_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[87] gpt_neox_layers_5_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[88] gpt_neox_layers_5_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[89] gpt_neox_layers_5_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[90] gpt_neox_layers_5_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[91] gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[92] gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[93] gpt_neox_layers_5_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[94] gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[95] gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[96] gpt_neox_layers_5_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[97] gpt_neox_layers_6_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[98] gpt_neox_layers_6_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[99] gpt_neox_layers_6_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[100] gpt_neox_layers_6_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[101] gpt_neox_layers_6_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[102] gpt_neox_layers_6_attention_query_key_value_q_scale2: R.Tensor((6144, 
64), dtype="float16") = packed_params[103] gpt_neox_layers_6_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[104] gpt_neox_layers_6_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[105] gpt_neox_layers_6_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[106] gpt_neox_layers_6_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[107] gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[108] gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[109] gpt_neox_layers_6_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[110] gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[111] gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[112] gpt_neox_layers_6_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[113] gpt_neox_layers_7_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[114] gpt_neox_layers_7_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[115] gpt_neox_layers_7_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[116] gpt_neox_layers_7_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[117] gpt_neox_layers_7_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[118] gpt_neox_layers_7_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[119] gpt_neox_layers_7_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[120] gpt_neox_layers_7_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[121] gpt_neox_layers_7_attention_dense_q_scale2: R.Tensor((2048, 64), 
dtype="float16") = packed_params[122] gpt_neox_layers_7_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[123] gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[124] gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[125] gpt_neox_layers_7_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[126] gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[127] gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[128] gpt_neox_layers_7_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[129] gpt_neox_layers_8_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[130] gpt_neox_layers_8_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[131] gpt_neox_layers_8_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[132] gpt_neox_layers_8_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[133] gpt_neox_layers_8_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[134] gpt_neox_layers_8_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[135] gpt_neox_layers_8_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[136] gpt_neox_layers_8_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[137] gpt_neox_layers_8_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[138] gpt_neox_layers_8_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[139] gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[140] gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = 
packed_params[141] gpt_neox_layers_8_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[142] gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[143] gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[144] gpt_neox_layers_8_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[145] gpt_neox_layers_9_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[146] gpt_neox_layers_9_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[147] gpt_neox_layers_9_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[148] gpt_neox_layers_9_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[149] gpt_neox_layers_9_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[150] gpt_neox_layers_9_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[151] gpt_neox_layers_9_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[152] gpt_neox_layers_9_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[153] gpt_neox_layers_9_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[154] gpt_neox_layers_9_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[155] gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[156] gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[157] gpt_neox_layers_9_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[158] gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[159] gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = 
packed_params[160] gpt_neox_layers_9_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[161] gpt_neox_layers_10_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[162] gpt_neox_layers_10_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[163] gpt_neox_layers_10_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[164] gpt_neox_layers_10_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[165] gpt_neox_layers_10_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[166] gpt_neox_layers_10_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[167] gpt_neox_layers_10_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[168] gpt_neox_layers_10_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[169] gpt_neox_layers_10_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[170] gpt_neox_layers_10_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[171] gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[172] gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[173] gpt_neox_layers_10_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[174] gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[175] gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[176] gpt_neox_layers_10_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[177] gpt_neox_layers_11_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[178] gpt_neox_layers_11_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = 
packed_params[179] gpt_neox_layers_11_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[180] gpt_neox_layers_11_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[181] gpt_neox_layers_11_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[182] gpt_neox_layers_11_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[183] gpt_neox_layers_11_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[184] gpt_neox_layers_11_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[185] gpt_neox_layers_11_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[186] gpt_neox_layers_11_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[187] gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[188] gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[189] gpt_neox_layers_11_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[190] gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[191] gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[192] gpt_neox_layers_11_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[193] gpt_neox_layers_12_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[194] gpt_neox_layers_12_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[195] gpt_neox_layers_12_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[196] gpt_neox_layers_12_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[197] gpt_neox_layers_12_attention_query_key_value_q_weight2: R.Tensor((6144, 256), 
dtype="uint32") = packed_params[198] gpt_neox_layers_12_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[199] gpt_neox_layers_12_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[200] gpt_neox_layers_12_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[201] gpt_neox_layers_12_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[202] gpt_neox_layers_12_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[203] gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[204] gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[205] gpt_neox_layers_12_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[206] gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[207] gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[208] gpt_neox_layers_12_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[209] gpt_neox_layers_13_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[210] gpt_neox_layers_13_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[211] gpt_neox_layers_13_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[212] gpt_neox_layers_13_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[213] gpt_neox_layers_13_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[214] gpt_neox_layers_13_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[215] gpt_neox_layers_13_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[216] gpt_neox_layers_13_attention_dense_q_weight2: 
R.Tensor((2048, 256), dtype="uint32") = packed_params[217] gpt_neox_layers_13_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[218] gpt_neox_layers_13_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[219] gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[220] gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[221] gpt_neox_layers_13_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[222] gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[223] gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[224] gpt_neox_layers_13_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[225] gpt_neox_layers_14_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[226] gpt_neox_layers_14_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[227] gpt_neox_layers_14_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[228] gpt_neox_layers_14_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[229] gpt_neox_layers_14_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[230] gpt_neox_layers_14_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[231] gpt_neox_layers_14_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[232] gpt_neox_layers_14_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[233] gpt_neox_layers_14_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[234] gpt_neox_layers_14_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[235] gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight2: 
R.Tensor((8192, 256), dtype="uint32") = packed_params[236] gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[237] gpt_neox_layers_14_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[238] gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[239] gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[240] gpt_neox_layers_14_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[241] gpt_neox_layers_15_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[242] gpt_neox_layers_15_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[243] gpt_neox_layers_15_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[244] gpt_neox_layers_15_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[245] gpt_neox_layers_15_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[246] gpt_neox_layers_15_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[247] gpt_neox_layers_15_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[248] gpt_neox_layers_15_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[249] gpt_neox_layers_15_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[250] gpt_neox_layers_15_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[251] gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[252] gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[253] gpt_neox_layers_15_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[254] gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight2: 
R.Tensor((2048, 1024), dtype="uint32") = packed_params[255] gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[256] gpt_neox_layers_15_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[257] gpt_neox_final_layer_norm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[258] gpt_neox_final_layer_norm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[259] embed_out_q_weight2: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[260] embed_out_q_scale2: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[261] layer_norm33 = R.call_tir(cls.layer_norm2, (input_embed, gpt_neox_layers_0_input_layernorm_weight2, gpt_neox_layers_0_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv83 = R.call_tir(cls.dequantize1, (gpt_neox_layers_0_attention_query_key_value_q_weight2, gpt_neox_layers_0_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv195 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm33, lv83, gpt_neox_layers_0_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv196 = R.call_tir(cls.fused_reshape8_reshape9, (lv195,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv84 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), lv196), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv197 = R.call_tir(cls.fused_reshape10_reshape11, (lv84,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv85 = R.call_tir(cls.dequantize2, (gpt_neox_layers_0_attention_dense_q_weight2, gpt_neox_layers_0_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm34 = R.call_tir(cls.layer_norm2, (input_embed, gpt_neox_layers_0_post_attention_layernorm_weight2, gpt_neox_layers_0_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv86 = 
R.call_tir(cls.dequantize3, (gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv198 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm34, lv86, gpt_neox_layers_0_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv87 = R.call_tir(cls.dequantize4, (gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv199 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv198, lv87, gpt_neox_layers_0_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv200 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv197, lv85, gpt_neox_layers_0_attention_dense_bias2, lv199, input_embed), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm35 = R.call_tir(cls.layer_norm2, (lv200, gpt_neox_layers_1_input_layernorm_weight2, gpt_neox_layers_1_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv88 = R.call_tir(cls.dequantize1, (gpt_neox_layers_1_attention_query_key_value_q_weight2, gpt_neox_layers_1_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv201 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm35, lv88, gpt_neox_layers_1_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv202 = R.call_tir(cls.fused_reshape8_reshape9, (lv201,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv89 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), lv202), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv203 = R.call_tir(cls.fused_reshape10_reshape11, (lv89,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv90 = R.call_tir(cls.dequantize2, (gpt_neox_layers_1_attention_dense_q_weight2, gpt_neox_layers_1_attention_dense_q_scale2), 
out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm36 = R.call_tir(cls.layer_norm2, (lv200, gpt_neox_layers_1_post_attention_layernorm_weight2, gpt_neox_layers_1_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv91 = R.call_tir(cls.dequantize3, (gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv204 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm36, lv91, gpt_neox_layers_1_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv92 = R.call_tir(cls.dequantize4, (gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv205 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv204, lv92, gpt_neox_layers_1_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv206 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv203, lv90, gpt_neox_layers_1_attention_dense_bias2, lv205, lv200), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm37 = R.call_tir(cls.layer_norm2, (lv206, gpt_neox_layers_2_input_layernorm_weight2, gpt_neox_layers_2_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv93 = R.call_tir(cls.dequantize1, (gpt_neox_layers_2_attention_query_key_value_q_weight2, gpt_neox_layers_2_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv207 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm37, lv93, gpt_neox_layers_2_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv208 = R.call_tir(cls.fused_reshape8_reshape9, (lv207,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv94 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), lv208), out_sinfo=R.Tensor((1, 8, 256), 
dtype="float16")) lv209 = R.call_tir(cls.fused_reshape10_reshape11, (lv94,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv95 = R.call_tir(cls.dequantize2, (gpt_neox_layers_2_attention_dense_q_weight2, gpt_neox_layers_2_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm38 = R.call_tir(cls.layer_norm2, (lv206, gpt_neox_layers_2_post_attention_layernorm_weight2, gpt_neox_layers_2_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv96 = R.call_tir(cls.dequantize3, (gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv210 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm38, lv96, gpt_neox_layers_2_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv97 = R.call_tir(cls.dequantize4, (gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv211 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv210, lv97, gpt_neox_layers_2_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv212 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv209, lv95, gpt_neox_layers_2_attention_dense_bias2, lv211, lv206), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm39 = R.call_tir(cls.layer_norm2, (lv212, gpt_neox_layers_3_input_layernorm_weight2, gpt_neox_layers_3_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv98 = R.call_tir(cls.dequantize1, (gpt_neox_layers_3_attention_query_key_value_q_weight2, gpt_neox_layers_3_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv213 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm39, lv98, gpt_neox_layers_3_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv214 = R.call_tir(cls.fused_reshape8_reshape9, 
(lv213,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv99 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), lv214), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv215 = R.call_tir(cls.fused_reshape10_reshape11, (lv99,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv100 = R.call_tir(cls.dequantize2, (gpt_neox_layers_3_attention_dense_q_weight2, gpt_neox_layers_3_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm40 = R.call_tir(cls.layer_norm2, (lv212, gpt_neox_layers_3_post_attention_layernorm_weight2, gpt_neox_layers_3_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv101 = R.call_tir(cls.dequantize3, (gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv216 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm40, lv101, gpt_neox_layers_3_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv102 = R.call_tir(cls.dequantize4, (gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv217 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv216, lv102, gpt_neox_layers_3_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv218 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv215, lv100, gpt_neox_layers_3_attention_dense_bias2, lv217, lv212), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm41 = R.call_tir(cls.layer_norm2, (lv218, gpt_neox_layers_4_input_layernorm_weight2, gpt_neox_layers_4_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv103 = R.call_tir(cls.dequantize1, (gpt_neox_layers_4_attention_query_key_value_q_weight2, gpt_neox_layers_4_attention_query_key_value_q_scale2), 
out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv219 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm41, lv103, gpt_neox_layers_4_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv220 = R.call_tir(cls.fused_reshape8_reshape9, (lv219,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv104 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), lv220), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv221 = R.call_tir(cls.fused_reshape10_reshape11, (lv104,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv105 = R.call_tir(cls.dequantize2, (gpt_neox_layers_4_attention_dense_q_weight2, gpt_neox_layers_4_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm42 = R.call_tir(cls.layer_norm2, (lv218, gpt_neox_layers_4_post_attention_layernorm_weight2, gpt_neox_layers_4_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv106 = R.call_tir(cls.dequantize3, (gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv222 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm42, lv106, gpt_neox_layers_4_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv107 = R.call_tir(cls.dequantize4, (gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv223 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv222, lv107, gpt_neox_layers_4_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv224 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv221, lv105, gpt_neox_layers_4_attention_dense_bias2, lv223, lv218), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm43 = R.call_tir(cls.layer_norm2, (lv224, 
gpt_neox_layers_5_input_layernorm_weight2, gpt_neox_layers_5_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv108 = R.call_tir(cls.dequantize1, (gpt_neox_layers_5_attention_query_key_value_q_weight2, gpt_neox_layers_5_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv225 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm43, lv108, gpt_neox_layers_5_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv226 = R.call_tir(cls.fused_reshape8_reshape9, (lv225,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv109 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), lv226), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv227 = R.call_tir(cls.fused_reshape10_reshape11, (lv109,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv110 = R.call_tir(cls.dequantize2, (gpt_neox_layers_5_attention_dense_q_weight2, gpt_neox_layers_5_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm44 = R.call_tir(cls.layer_norm2, (lv224, gpt_neox_layers_5_post_attention_layernorm_weight2, gpt_neox_layers_5_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv111 = R.call_tir(cls.dequantize3, (gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv228 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm44, lv111, gpt_neox_layers_5_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv112 = R.call_tir(cls.dequantize4, (gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv229 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv228, lv112, gpt_neox_layers_5_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 
2048), dtype="float16")) lv230 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv227, lv110, gpt_neox_layers_5_attention_dense_bias2, lv229, lv224), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm45 = R.call_tir(cls.layer_norm2, (lv230, gpt_neox_layers_6_input_layernorm_weight2, gpt_neox_layers_6_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv113 = R.call_tir(cls.dequantize1, (gpt_neox_layers_6_attention_query_key_value_q_weight2, gpt_neox_layers_6_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv231 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm45, lv113, gpt_neox_layers_6_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv232 = R.call_tir(cls.fused_reshape8_reshape9, (lv231,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv114 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), lv232), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv233 = R.call_tir(cls.fused_reshape10_reshape11, (lv114,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv115 = R.call_tir(cls.dequantize2, (gpt_neox_layers_6_attention_dense_q_weight2, gpt_neox_layers_6_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm46 = R.call_tir(cls.layer_norm2, (lv230, gpt_neox_layers_6_post_attention_layernorm_weight2, gpt_neox_layers_6_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv116 = R.call_tir(cls.dequantize3, (gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv234 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm46, lv116, gpt_neox_layers_6_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv117 = R.call_tir(cls.dequantize4, 
(gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv235 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv234, lv117, gpt_neox_layers_6_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv236 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv233, lv115, gpt_neox_layers_6_attention_dense_bias2, lv235, lv230), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm47 = R.call_tir(cls.layer_norm2, (lv236, gpt_neox_layers_7_input_layernorm_weight2, gpt_neox_layers_7_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv118 = R.call_tir(cls.dequantize1, (gpt_neox_layers_7_attention_query_key_value_q_weight2, gpt_neox_layers_7_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv237 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm47, lv118, gpt_neox_layers_7_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv238 = R.call_tir(cls.fused_reshape8_reshape9, (lv237,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv119 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), lv238), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv239 = R.call_tir(cls.fused_reshape10_reshape11, (lv119,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv120 = R.call_tir(cls.dequantize2, (gpt_neox_layers_7_attention_dense_q_weight2, gpt_neox_layers_7_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm48 = R.call_tir(cls.layer_norm2, (lv236, gpt_neox_layers_7_post_attention_layernorm_weight2, gpt_neox_layers_7_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv121 = R.call_tir(cls.dequantize3, (gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale2), 
out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv240 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm48, lv121, gpt_neox_layers_7_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv122 = R.call_tir(cls.dequantize4, (gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv241 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv240, lv122, gpt_neox_layers_7_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv242 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv239, lv120, gpt_neox_layers_7_attention_dense_bias2, lv241, lv236), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm49 = R.call_tir(cls.layer_norm2, (lv242, gpt_neox_layers_8_input_layernorm_weight2, gpt_neox_layers_8_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv123 = R.call_tir(cls.dequantize1, (gpt_neox_layers_8_attention_query_key_value_q_weight2, gpt_neox_layers_8_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv243 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm49, lv123, gpt_neox_layers_8_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv244 = R.call_tir(cls.fused_reshape8_reshape9, (lv243,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv124 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), lv244), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv245 = R.call_tir(cls.fused_reshape10_reshape11, (lv124,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv125 = R.call_tir(cls.dequantize2, (gpt_neox_layers_8_attention_dense_q_weight2, gpt_neox_layers_8_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm50 = R.call_tir(cls.layer_norm2, (lv242, 
gpt_neox_layers_8_post_attention_layernorm_weight2, gpt_neox_layers_8_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv126 = R.call_tir(cls.dequantize3, (gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv246 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm50, lv126, gpt_neox_layers_8_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv127 = R.call_tir(cls.dequantize4, (gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv247 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv246, lv127, gpt_neox_layers_8_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv248 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv245, lv125, gpt_neox_layers_8_attention_dense_bias2, lv247, lv242), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm51 = R.call_tir(cls.layer_norm2, (lv248, gpt_neox_layers_9_input_layernorm_weight2, gpt_neox_layers_9_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv128 = R.call_tir(cls.dequantize1, (gpt_neox_layers_9_attention_query_key_value_q_weight2, gpt_neox_layers_9_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv249 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm51, lv128, gpt_neox_layers_9_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv250 = R.call_tir(cls.fused_reshape8_reshape9, (lv249,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv129 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), lv250), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv251 = R.call_tir(cls.fused_reshape10_reshape11, (lv129,), out_sinfo=R.Tensor((1, 1, 
2048), dtype="float16")) lv130 = R.call_tir(cls.dequantize2, (gpt_neox_layers_9_attention_dense_q_weight2, gpt_neox_layers_9_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm52 = R.call_tir(cls.layer_norm2, (lv248, gpt_neox_layers_9_post_attention_layernorm_weight2, gpt_neox_layers_9_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv131 = R.call_tir(cls.dequantize3, (gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv252 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm52, lv131, gpt_neox_layers_9_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv132 = R.call_tir(cls.dequantize4, (gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv253 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv252, lv132, gpt_neox_layers_9_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv254 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv251, lv130, gpt_neox_layers_9_attention_dense_bias2, lv253, lv248), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm53 = R.call_tir(cls.layer_norm2, (lv254, gpt_neox_layers_10_input_layernorm_weight2, gpt_neox_layers_10_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv133 = R.call_tir(cls.dequantize1, (gpt_neox_layers_10_attention_query_key_value_q_weight2, gpt_neox_layers_10_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv255 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm53, lv133, gpt_neox_layers_10_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv256 = R.call_tir(cls.fused_reshape8_reshape9, (lv255,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv134 = 
R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), lv256), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv257 = R.call_tir(cls.fused_reshape10_reshape11, (lv134,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv135 = R.call_tir(cls.dequantize2, (gpt_neox_layers_10_attention_dense_q_weight2, gpt_neox_layers_10_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm54 = R.call_tir(cls.layer_norm2, (lv254, gpt_neox_layers_10_post_attention_layernorm_weight2, gpt_neox_layers_10_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv136 = R.call_tir(cls.dequantize3, (gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv258 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm54, lv136, gpt_neox_layers_10_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv137 = R.call_tir(cls.dequantize4, (gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv259 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv258, lv137, gpt_neox_layers_10_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv260 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv257, lv135, gpt_neox_layers_10_attention_dense_bias2, lv259, lv254), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm55 = R.call_tir(cls.layer_norm2, (lv260, gpt_neox_layers_11_input_layernorm_weight2, gpt_neox_layers_11_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv138 = R.call_tir(cls.dequantize1, (gpt_neox_layers_11_attention_query_key_value_q_weight2, gpt_neox_layers_11_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv261 = 
R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm55, lv138, gpt_neox_layers_11_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv262 = R.call_tir(cls.fused_reshape8_reshape9, (lv261,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv139 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), lv262), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv263 = R.call_tir(cls.fused_reshape10_reshape11, (lv139,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv140 = R.call_tir(cls.dequantize2, (gpt_neox_layers_11_attention_dense_q_weight2, gpt_neox_layers_11_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm56 = R.call_tir(cls.layer_norm2, (lv260, gpt_neox_layers_11_post_attention_layernorm_weight2, gpt_neox_layers_11_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv141 = R.call_tir(cls.dequantize3, (gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv264 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm56, lv141, gpt_neox_layers_11_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv142 = R.call_tir(cls.dequantize4, (gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv265 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv264, lv142, gpt_neox_layers_11_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv266 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv263, lv140, gpt_neox_layers_11_attention_dense_bias2, lv265, lv260), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm57 = R.call_tir(cls.layer_norm2, (lv266, gpt_neox_layers_12_input_layernorm_weight2, 
gpt_neox_layers_12_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv143 = R.call_tir(cls.dequantize1, (gpt_neox_layers_12_attention_query_key_value_q_weight2, gpt_neox_layers_12_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv267 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm57, lv143, gpt_neox_layers_12_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv268 = R.call_tir(cls.fused_reshape8_reshape9, (lv267,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv144 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), lv268), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv269 = R.call_tir(cls.fused_reshape10_reshape11, (lv144,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv145 = R.call_tir(cls.dequantize2, (gpt_neox_layers_12_attention_dense_q_weight2, gpt_neox_layers_12_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm58 = R.call_tir(cls.layer_norm2, (lv266, gpt_neox_layers_12_post_attention_layernorm_weight2, gpt_neox_layers_12_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv146 = R.call_tir(cls.dequantize3, (gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv270 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm58, lv146, gpt_neox_layers_12_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv147 = R.call_tir(cls.dequantize4, (gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv271 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv270, lv147, gpt_neox_layers_12_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv272 
= R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv269, lv145, gpt_neox_layers_12_attention_dense_bias2, lv271, lv266), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm59 = R.call_tir(cls.layer_norm2, (lv272, gpt_neox_layers_13_input_layernorm_weight2, gpt_neox_layers_13_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv148 = R.call_tir(cls.dequantize1, (gpt_neox_layers_13_attention_query_key_value_q_weight2, gpt_neox_layers_13_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv273 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm59, lv148, gpt_neox_layers_13_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv274 = R.call_tir(cls.fused_reshape8_reshape9, (lv273,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv149 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), lv274), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv275 = R.call_tir(cls.fused_reshape10_reshape11, (lv149,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv150 = R.call_tir(cls.dequantize2, (gpt_neox_layers_13_attention_dense_q_weight2, gpt_neox_layers_13_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm60 = R.call_tir(cls.layer_norm2, (lv272, gpt_neox_layers_13_post_attention_layernorm_weight2, gpt_neox_layers_13_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv151 = R.call_tir(cls.dequantize3, (gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv276 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm60, lv151, gpt_neox_layers_13_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv152 = R.call_tir(cls.dequantize4, 
(gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv277 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv276, lv152, gpt_neox_layers_13_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv278 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv275, lv150, gpt_neox_layers_13_attention_dense_bias2, lv277, lv272), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm61 = R.call_tir(cls.layer_norm2, (lv278, gpt_neox_layers_14_input_layernorm_weight2, gpt_neox_layers_14_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv153 = R.call_tir(cls.dequantize1, (gpt_neox_layers_14_attention_query_key_value_q_weight2, gpt_neox_layers_14_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv279 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm61, lv153, gpt_neox_layers_14_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv280 = R.call_tir(cls.fused_reshape8_reshape9, (lv279,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv154 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), lv280), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv281 = R.call_tir(cls.fused_reshape10_reshape11, (lv154,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv155 = R.call_tir(cls.dequantize2, (gpt_neox_layers_14_attention_dense_q_weight2, gpt_neox_layers_14_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm62 = R.call_tir(cls.layer_norm2, (lv278, gpt_neox_layers_14_post_attention_layernorm_weight2, gpt_neox_layers_14_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv156 = R.call_tir(cls.dequantize3, (gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight2, 
gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv282 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm62, lv156, gpt_neox_layers_14_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv157 = R.call_tir(cls.dequantize4, (gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv283 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv282, lv157, gpt_neox_layers_14_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv284 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv281, lv155, gpt_neox_layers_14_attention_dense_bias2, lv283, lv278), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm63 = R.call_tir(cls.layer_norm2, (lv284, gpt_neox_layers_15_input_layernorm_weight2, gpt_neox_layers_15_input_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv158 = R.call_tir(cls.dequantize1, (gpt_neox_layers_15_attention_query_key_value_q_weight2, gpt_neox_layers_15_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv285 = R.call_tir(cls.fused_NT_matmul10_add10, (layer_norm63, lv158, gpt_neox_layers_15_attention_query_key_value_bias2), out_sinfo=R.Tensor((1, 1, 6144), dtype="float16")) lv286 = R.call_tir(cls.fused_reshape8_reshape9, (lv285,), out_sinfo=R.Tensor((1, 24, 256), dtype="float16")) lv159 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), lv286), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) lv287 = R.call_tir(cls.fused_reshape10_reshape11, (lv159,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv160 = R.call_tir(cls.dequantize2, (gpt_neox_layers_15_attention_dense_q_weight2, gpt_neox_layers_15_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm64 = 
R.call_tir(cls.layer_norm2, (lv284, gpt_neox_layers_15_post_attention_layernorm_weight2, gpt_neox_layers_15_post_attention_layernorm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv161 = R.call_tir(cls.dequantize3, (gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv288 = R.call_tir(cls.fused_NT_matmul12_add12_gelu2_cast6, (layer_norm64, lv161, gpt_neox_layers_15_mlp_dense_h_to_4h_bias2), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16")) lv162 = R.call_tir(cls.dequantize4, (gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv289 = R.call_tir(cls.fused_NT_matmul13_add13_cast7, (lv288, lv162, gpt_neox_layers_15_mlp_dense_4h_to_h_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv290 = R.call_tir(cls.fused_NT_matmul11_add11_add14_add14, (lv287, lv160, gpt_neox_layers_15_attention_dense_bias2, lv289, lv284), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) layer_norm65 = R.call_tir(cls.layer_norm2, (lv290, gpt_neox_final_layer_norm_weight2, gpt_neox_final_layer_norm_bias2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv163 = R.call_tir(cls.dequantize, (embed_out_q_weight2, embed_out_q_scale2), out_sinfo=R.Tensor((vocab_size, 2048), dtype="float16")) lv291 = R.call_tir(cls.fused_NT_matmul14_cast8, (layer_norm65, lv163), out_sinfo=R.Tensor((1, 1, vocab_size), dtype="float32")) gv2: R.Tuple(R.Tensor((1, 1, vocab_size), dtype="float32"), R.Object) = lv291, paged_kv_cache R.output(gv2) return gv2 @R.function def embed(input_ids: R.Tensor(("seq_len",), dtype="int32"), packed_params: R.Tuple(R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), 
dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), 
dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), 
dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"))) -> R.Tensor(("seq_len", 2048), dtype="float16"): seq_len = T.int64() vocab_size = T.int64() R.func_attr({"num_input": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module with R.dataflow(): 
gpt_neox_embed_in_q_weight: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[0] gpt_neox_embed_in_q_scale: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[1] lv = R.call_tir(cls.dequantize, (gpt_neox_embed_in_q_weight, gpt_neox_embed_in_q_scale), out_sinfo=R.Tensor((vocab_size, 2048), dtype="float16")) gv = R.call_tir(cls.take1, (lv, input_ids), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16")) R.output(gv) return gv @R.function def prefill(input_embed: R.Tensor((1, "seq_len", 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), 
R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), 
R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), 
R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, "vocab_size"), dtype="float32"), R.Object): vocab_size = T.int64() seq_len = T.int64() R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module with R.dataflow(): gpt_neox_layers_0_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[2] gpt_neox_layers_0_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[3] gpt_neox_layers_0_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[4] gpt_neox_layers_0_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[5] gpt_neox_layers_0_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[6] gpt_neox_layers_0_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[7] gpt_neox_layers_0_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[8] gpt_neox_layers_0_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[9] gpt_neox_layers_0_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[10] gpt_neox_layers_0_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[11] gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), 
dtype="uint32") = packed_params[12] gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[13] gpt_neox_layers_0_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[14] gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[15] gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[16] gpt_neox_layers_0_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[17] gpt_neox_layers_1_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[18] gpt_neox_layers_1_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[19] gpt_neox_layers_1_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[20] gpt_neox_layers_1_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[21] gpt_neox_layers_1_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[22] gpt_neox_layers_1_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[23] gpt_neox_layers_1_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[24] gpt_neox_layers_1_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[25] gpt_neox_layers_1_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[26] gpt_neox_layers_1_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[27] gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[28] gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[29] gpt_neox_layers_1_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[30] gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[31] 
gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[32] gpt_neox_layers_1_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[33] gpt_neox_layers_2_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[34] gpt_neox_layers_2_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[35] gpt_neox_layers_2_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[36] gpt_neox_layers_2_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[37] gpt_neox_layers_2_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[38] gpt_neox_layers_2_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[39] gpt_neox_layers_2_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[40] gpt_neox_layers_2_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[41] gpt_neox_layers_2_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[42] gpt_neox_layers_2_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[43] gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[44] gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[45] gpt_neox_layers_2_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[46] gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[47] gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[48] gpt_neox_layers_2_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[49] gpt_neox_layers_3_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[50] gpt_neox_layers_3_input_layernorm_bias1: 
R.Tensor((2048,), dtype="float16") = packed_params[51] gpt_neox_layers_3_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[52] gpt_neox_layers_3_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[53] gpt_neox_layers_3_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[54] gpt_neox_layers_3_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[55] gpt_neox_layers_3_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[56] gpt_neox_layers_3_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[57] gpt_neox_layers_3_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[58] gpt_neox_layers_3_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[59] gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[60] gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[61] gpt_neox_layers_3_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[62] gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[63] gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[64] gpt_neox_layers_3_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[65] gpt_neox_layers_4_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[66] gpt_neox_layers_4_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[67] gpt_neox_layers_4_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[68] gpt_neox_layers_4_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[69] gpt_neox_layers_4_attention_query_key_value_q_weight1: R.Tensor((6144, 256), 
dtype="uint32") = packed_params[70] gpt_neox_layers_4_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[71] gpt_neox_layers_4_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[72] gpt_neox_layers_4_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[73] gpt_neox_layers_4_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[74] gpt_neox_layers_4_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[75] gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[76] gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[77] gpt_neox_layers_4_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[78] gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[79] gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[80] gpt_neox_layers_4_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[81] gpt_neox_layers_5_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[82] gpt_neox_layers_5_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[83] gpt_neox_layers_5_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[84] gpt_neox_layers_5_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[85] gpt_neox_layers_5_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[86] gpt_neox_layers_5_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[87] gpt_neox_layers_5_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[88] gpt_neox_layers_5_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = 
packed_params[89] gpt_neox_layers_5_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[90] gpt_neox_layers_5_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[91] gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[92] gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[93] gpt_neox_layers_5_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[94] gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[95] gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[96] gpt_neox_layers_5_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[97] gpt_neox_layers_6_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[98] gpt_neox_layers_6_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[99] gpt_neox_layers_6_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[100] gpt_neox_layers_6_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[101] gpt_neox_layers_6_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[102] gpt_neox_layers_6_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[103] gpt_neox_layers_6_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[104] gpt_neox_layers_6_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[105] gpt_neox_layers_6_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[106] gpt_neox_layers_6_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[107] gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[108] 
gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[109] gpt_neox_layers_6_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[110] gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[111] gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[112] gpt_neox_layers_6_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[113] gpt_neox_layers_7_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[114] gpt_neox_layers_7_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[115] gpt_neox_layers_7_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[116] gpt_neox_layers_7_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[117] gpt_neox_layers_7_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[118] gpt_neox_layers_7_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[119] gpt_neox_layers_7_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[120] gpt_neox_layers_7_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[121] gpt_neox_layers_7_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[122] gpt_neox_layers_7_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[123] gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[124] gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[125] gpt_neox_layers_7_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[126] gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[127] 
gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[128] gpt_neox_layers_7_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[129] gpt_neox_layers_8_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[130] gpt_neox_layers_8_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[131] gpt_neox_layers_8_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[132] gpt_neox_layers_8_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[133] gpt_neox_layers_8_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[134] gpt_neox_layers_8_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[135] gpt_neox_layers_8_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[136] gpt_neox_layers_8_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[137] gpt_neox_layers_8_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[138] gpt_neox_layers_8_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[139] gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[140] gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[141] gpt_neox_layers_8_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[142] gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[143] gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[144] gpt_neox_layers_8_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[145] gpt_neox_layers_9_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[146] 
gpt_neox_layers_9_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[147] gpt_neox_layers_9_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[148] gpt_neox_layers_9_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[149] gpt_neox_layers_9_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[150] gpt_neox_layers_9_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[151] gpt_neox_layers_9_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[152] gpt_neox_layers_9_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[153] gpt_neox_layers_9_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[154] gpt_neox_layers_9_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[155] gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[156] gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[157] gpt_neox_layers_9_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[158] gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[159] gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[160] gpt_neox_layers_9_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[161] gpt_neox_layers_10_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[162] gpt_neox_layers_10_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[163] gpt_neox_layers_10_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[164] gpt_neox_layers_10_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[165] 
gpt_neox_layers_10_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[166] gpt_neox_layers_10_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[167] gpt_neox_layers_10_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[168] gpt_neox_layers_10_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[169] gpt_neox_layers_10_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[170] gpt_neox_layers_10_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[171] gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[172] gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[173] gpt_neox_layers_10_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[174] gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[175] gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[176] gpt_neox_layers_10_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[177] gpt_neox_layers_11_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[178] gpt_neox_layers_11_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[179] gpt_neox_layers_11_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[180] gpt_neox_layers_11_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[181] gpt_neox_layers_11_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[182] gpt_neox_layers_11_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[183] gpt_neox_layers_11_attention_query_key_value_bias1: R.Tensor((6144,), 
dtype="float16") = packed_params[184] gpt_neox_layers_11_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[185] gpt_neox_layers_11_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[186] gpt_neox_layers_11_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[187] gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[188] gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[189] gpt_neox_layers_11_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[190] gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[191] gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[192] gpt_neox_layers_11_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[193] gpt_neox_layers_12_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[194] gpt_neox_layers_12_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[195] gpt_neox_layers_12_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[196] gpt_neox_layers_12_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[197] gpt_neox_layers_12_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[198] gpt_neox_layers_12_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[199] gpt_neox_layers_12_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[200] gpt_neox_layers_12_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[201] gpt_neox_layers_12_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[202] gpt_neox_layers_12_attention_dense_bias1: R.Tensor((2048,), 
dtype="float16") = packed_params[203] gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[204] gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[205] gpt_neox_layers_12_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[206] gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[207] gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[208] gpt_neox_layers_12_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[209] gpt_neox_layers_13_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[210] gpt_neox_layers_13_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[211] gpt_neox_layers_13_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[212] gpt_neox_layers_13_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[213] gpt_neox_layers_13_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[214] gpt_neox_layers_13_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[215] gpt_neox_layers_13_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[216] gpt_neox_layers_13_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[217] gpt_neox_layers_13_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[218] gpt_neox_layers_13_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[219] gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[220] gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[221] gpt_neox_layers_13_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), 
dtype="float32") = packed_params[222] gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[223] gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[224] gpt_neox_layers_13_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[225] gpt_neox_layers_14_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[226] gpt_neox_layers_14_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[227] gpt_neox_layers_14_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[228] gpt_neox_layers_14_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[229] gpt_neox_layers_14_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[230] gpt_neox_layers_14_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[231] gpt_neox_layers_14_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[232] gpt_neox_layers_14_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[233] gpt_neox_layers_14_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[234] gpt_neox_layers_14_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[235] gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[236] gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[237] gpt_neox_layers_14_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[238] gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[239] gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[240] gpt_neox_layers_14_mlp_dense_4h_to_h_bias1: 
R.Tensor((2048,), dtype="float32") = packed_params[241] gpt_neox_layers_15_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[242] gpt_neox_layers_15_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[243] gpt_neox_layers_15_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[244] gpt_neox_layers_15_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[245] gpt_neox_layers_15_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[246] gpt_neox_layers_15_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[247] gpt_neox_layers_15_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[248] gpt_neox_layers_15_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[249] gpt_neox_layers_15_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[250] gpt_neox_layers_15_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[251] gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[252] gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[253] gpt_neox_layers_15_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[254] gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[255] gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[256] gpt_neox_layers_15_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[257] gpt_neox_final_layer_norm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[258] gpt_neox_final_layer_norm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[259] embed_out_q_weight1: R.Tensor((vocab_size, 256), dtype="uint32") = 
packed_params[260] embed_out_q_scale1: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[261] layer_norm = R.call_tir(cls.layer_norm1, (input_embed, gpt_neox_layers_0_input_layernorm_weight1, gpt_neox_layers_0_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv1 = R.call_tir(cls.dequantize1, (gpt_neox_layers_0_attention_query_key_value_q_weight1, gpt_neox_layers_0_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv292 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm, lv1, gpt_neox_layers_0_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape = R.call_tir(cls.reshape4, (lv292,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape1 = R.call_tir(cls.reshape5, (reshape,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv2 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape1), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape2 = R.call_tir(cls.reshape6, (lv2,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape3 = R.call_tir(cls.reshape7, (reshape2,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv3 = R.call_tir(cls.dequantize2, (gpt_neox_layers_0_attention_dense_q_weight1, gpt_neox_layers_0_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm1 = R.call_tir(cls.layer_norm1, (input_embed, gpt_neox_layers_0_post_attention_layernorm_weight1, gpt_neox_layers_0_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv4 = R.call_tir(cls.dequantize3, (gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv293 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm1, lv4, gpt_neox_layers_0_mlp_dense_h_to_4h_bias1), 
out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv5 = R.call_tir(cls.dequantize4, (gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv294 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv293, lv5, gpt_neox_layers_0_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv295 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape3, lv3, gpt_neox_layers_0_attention_dense_bias1, lv294, input_embed), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm2 = R.call_tir(cls.layer_norm1, (lv295, gpt_neox_layers_1_input_layernorm_weight1, gpt_neox_layers_1_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv6 = R.call_tir(cls.dequantize1, (gpt_neox_layers_1_attention_query_key_value_q_weight1, gpt_neox_layers_1_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv296 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm2, lv6, gpt_neox_layers_1_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape4 = R.call_tir(cls.reshape4, (lv296,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape5 = R.call_tir(cls.reshape5, (reshape4,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv7 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape5), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape6 = R.call_tir(cls.reshape6, (lv7,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape7 = R.call_tir(cls.reshape7, (reshape6,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv8 = R.call_tir(cls.dequantize2, (gpt_neox_layers_1_attention_dense_q_weight1, gpt_neox_layers_1_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm3 = 
R.call_tir(cls.layer_norm1, (lv295, gpt_neox_layers_1_post_attention_layernorm_weight1, gpt_neox_layers_1_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv9 = R.call_tir(cls.dequantize3, (gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv297 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm3, lv9, gpt_neox_layers_1_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv10 = R.call_tir(cls.dequantize4, (gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv298 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv297, lv10, gpt_neox_layers_1_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv299 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape7, lv8, gpt_neox_layers_1_attention_dense_bias1, lv298, lv295), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm4 = R.call_tir(cls.layer_norm1, (lv299, gpt_neox_layers_2_input_layernorm_weight1, gpt_neox_layers_2_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv11 = R.call_tir(cls.dequantize1, (gpt_neox_layers_2_attention_query_key_value_q_weight1, gpt_neox_layers_2_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv300 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm4, lv11, gpt_neox_layers_2_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape8 = R.call_tir(cls.reshape4, (lv300,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape9 = R.call_tir(cls.reshape5, (reshape8,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv12 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), 
R.prim_value(T.float32(1.0)), reshape9), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape10 = R.call_tir(cls.reshape6, (lv12,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape11 = R.call_tir(cls.reshape7, (reshape10,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv13 = R.call_tir(cls.dequantize2, (gpt_neox_layers_2_attention_dense_q_weight1, gpt_neox_layers_2_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm5 = R.call_tir(cls.layer_norm1, (lv299, gpt_neox_layers_2_post_attention_layernorm_weight1, gpt_neox_layers_2_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv14 = R.call_tir(cls.dequantize3, (gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv301 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm5, lv14, gpt_neox_layers_2_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv15 = R.call_tir(cls.dequantize4, (gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv302 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv301, lv15, gpt_neox_layers_2_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv303 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape11, lv13, gpt_neox_layers_2_attention_dense_bias1, lv302, lv299), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm6 = R.call_tir(cls.layer_norm1, (lv303, gpt_neox_layers_3_input_layernorm_weight1, gpt_neox_layers_3_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv16 = R.call_tir(cls.dequantize1, (gpt_neox_layers_3_attention_query_key_value_q_weight1, gpt_neox_layers_3_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv304 = 
R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm6, lv16, gpt_neox_layers_3_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape12 = R.call_tir(cls.reshape4, (lv304,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape13 = R.call_tir(cls.reshape5, (reshape12,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv17 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape13), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape14 = R.call_tir(cls.reshape6, (lv17,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape15 = R.call_tir(cls.reshape7, (reshape14,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv18 = R.call_tir(cls.dequantize2, (gpt_neox_layers_3_attention_dense_q_weight1, gpt_neox_layers_3_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm7 = R.call_tir(cls.layer_norm1, (lv303, gpt_neox_layers_3_post_attention_layernorm_weight1, gpt_neox_layers_3_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv19 = R.call_tir(cls.dequantize3, (gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv305 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm7, lv19, gpt_neox_layers_3_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv20 = R.call_tir(cls.dequantize4, (gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv306 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv305, lv20, gpt_neox_layers_3_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv307 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape15, lv18, 
gpt_neox_layers_3_attention_dense_bias1, lv306, lv303), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm8 = R.call_tir(cls.layer_norm1, (lv307, gpt_neox_layers_4_input_layernorm_weight1, gpt_neox_layers_4_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv21 = R.call_tir(cls.dequantize1, (gpt_neox_layers_4_attention_query_key_value_q_weight1, gpt_neox_layers_4_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv308 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm8, lv21, gpt_neox_layers_4_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape16 = R.call_tir(cls.reshape4, (lv308,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape17 = R.call_tir(cls.reshape5, (reshape16,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv22 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape17), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape18 = R.call_tir(cls.reshape6, (lv22,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape19 = R.call_tir(cls.reshape7, (reshape18,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv23 = R.call_tir(cls.dequantize2, (gpt_neox_layers_4_attention_dense_q_weight1, gpt_neox_layers_4_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm9 = R.call_tir(cls.layer_norm1, (lv307, gpt_neox_layers_4_post_attention_layernorm_weight1, gpt_neox_layers_4_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv24 = R.call_tir(cls.dequantize3, (gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv309 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm9, lv24, 
gpt_neox_layers_4_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv25 = R.call_tir(cls.dequantize4, (gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv310 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv309, lv25, gpt_neox_layers_4_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv311 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape19, lv23, gpt_neox_layers_4_attention_dense_bias1, lv310, lv307), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm10 = R.call_tir(cls.layer_norm1, (lv311, gpt_neox_layers_5_input_layernorm_weight1, gpt_neox_layers_5_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv26 = R.call_tir(cls.dequantize1, (gpt_neox_layers_5_attention_query_key_value_q_weight1, gpt_neox_layers_5_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv312 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm10, lv26, gpt_neox_layers_5_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape20 = R.call_tir(cls.reshape4, (lv312,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape21 = R.call_tir(cls.reshape5, (reshape20,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv27 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape21), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape22 = R.call_tir(cls.reshape6, (lv27,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape23 = R.call_tir(cls.reshape7, (reshape22,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv28 = R.call_tir(cls.dequantize2, (gpt_neox_layers_5_attention_dense_q_weight1, gpt_neox_layers_5_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 
2048), dtype="float16")) layer_norm11 = R.call_tir(cls.layer_norm1, (lv311, gpt_neox_layers_5_post_attention_layernorm_weight1, gpt_neox_layers_5_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv29 = R.call_tir(cls.dequantize3, (gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv313 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm11, lv29, gpt_neox_layers_5_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv30 = R.call_tir(cls.dequantize4, (gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv314 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv313, lv30, gpt_neox_layers_5_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv315 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape23, lv28, gpt_neox_layers_5_attention_dense_bias1, lv314, lv311), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm12 = R.call_tir(cls.layer_norm1, (lv315, gpt_neox_layers_6_input_layernorm_weight1, gpt_neox_layers_6_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv31 = R.call_tir(cls.dequantize1, (gpt_neox_layers_6_attention_query_key_value_q_weight1, gpt_neox_layers_6_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv316 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm12, lv31, gpt_neox_layers_6_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape24 = R.call_tir(cls.reshape4, (lv316,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape25 = R.call_tir(cls.reshape5, (reshape24,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv32 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", 
(paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape25), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape26 = R.call_tir(cls.reshape6, (lv32,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape27 = R.call_tir(cls.reshape7, (reshape26,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv33 = R.call_tir(cls.dequantize2, (gpt_neox_layers_6_attention_dense_q_weight1, gpt_neox_layers_6_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm13 = R.call_tir(cls.layer_norm1, (lv315, gpt_neox_layers_6_post_attention_layernorm_weight1, gpt_neox_layers_6_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv34 = R.call_tir(cls.dequantize3, (gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv317 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm13, lv34, gpt_neox_layers_6_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv35 = R.call_tir(cls.dequantize4, (gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv318 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv317, lv35, gpt_neox_layers_6_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv319 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape27, lv33, gpt_neox_layers_6_attention_dense_bias1, lv318, lv315), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm14 = R.call_tir(cls.layer_norm1, (lv319, gpt_neox_layers_7_input_layernorm_weight1, gpt_neox_layers_7_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv36 = R.call_tir(cls.dequantize1, (gpt_neox_layers_7_attention_query_key_value_q_weight1, gpt_neox_layers_7_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 
2048), dtype="float16")) lv320 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm14, lv36, gpt_neox_layers_7_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape28 = R.call_tir(cls.reshape4, (lv320,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape29 = R.call_tir(cls.reshape5, (reshape28,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv37 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape29), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape30 = R.call_tir(cls.reshape6, (lv37,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape31 = R.call_tir(cls.reshape7, (reshape30,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv38 = R.call_tir(cls.dequantize2, (gpt_neox_layers_7_attention_dense_q_weight1, gpt_neox_layers_7_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm15 = R.call_tir(cls.layer_norm1, (lv319, gpt_neox_layers_7_post_attention_layernorm_weight1, gpt_neox_layers_7_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv39 = R.call_tir(cls.dequantize3, (gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv321 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm15, lv39, gpt_neox_layers_7_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv40 = R.call_tir(cls.dequantize4, (gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv322 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv321, lv40, gpt_neox_layers_7_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv323 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, 
(reshape31, lv38, gpt_neox_layers_7_attention_dense_bias1, lv322, lv319), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm16 = R.call_tir(cls.layer_norm1, (lv323, gpt_neox_layers_8_input_layernorm_weight1, gpt_neox_layers_8_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv41 = R.call_tir(cls.dequantize1, (gpt_neox_layers_8_attention_query_key_value_q_weight1, gpt_neox_layers_8_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv324 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm16, lv41, gpt_neox_layers_8_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape32 = R.call_tir(cls.reshape4, (lv324,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape33 = R.call_tir(cls.reshape5, (reshape32,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv42 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape33), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape34 = R.call_tir(cls.reshape6, (lv42,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape35 = R.call_tir(cls.reshape7, (reshape34,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv43 = R.call_tir(cls.dequantize2, (gpt_neox_layers_8_attention_dense_q_weight1, gpt_neox_layers_8_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm17 = R.call_tir(cls.layer_norm1, (lv323, gpt_neox_layers_8_post_attention_layernorm_weight1, gpt_neox_layers_8_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv44 = R.call_tir(cls.dequantize3, (gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv325 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm17, lv44, 
gpt_neox_layers_8_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv45 = R.call_tir(cls.dequantize4, (gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv326 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv325, lv45, gpt_neox_layers_8_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv327 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape35, lv43, gpt_neox_layers_8_attention_dense_bias1, lv326, lv323), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm18 = R.call_tir(cls.layer_norm1, (lv327, gpt_neox_layers_9_input_layernorm_weight1, gpt_neox_layers_9_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv46 = R.call_tir(cls.dequantize1, (gpt_neox_layers_9_attention_query_key_value_q_weight1, gpt_neox_layers_9_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv328 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm18, lv46, gpt_neox_layers_9_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape36 = R.call_tir(cls.reshape4, (lv328,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape37 = R.call_tir(cls.reshape5, (reshape36,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv47 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape37), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape38 = R.call_tir(cls.reshape6, (lv47,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape39 = R.call_tir(cls.reshape7, (reshape38,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv48 = R.call_tir(cls.dequantize2, (gpt_neox_layers_9_attention_dense_q_weight1, gpt_neox_layers_9_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 
2048), dtype="float16")) layer_norm19 = R.call_tir(cls.layer_norm1, (lv327, gpt_neox_layers_9_post_attention_layernorm_weight1, gpt_neox_layers_9_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv49 = R.call_tir(cls.dequantize3, (gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv329 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm19, lv49, gpt_neox_layers_9_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv50 = R.call_tir(cls.dequantize4, (gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv330 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv329, lv50, gpt_neox_layers_9_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv331 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape39, lv48, gpt_neox_layers_9_attention_dense_bias1, lv330, lv327), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm20 = R.call_tir(cls.layer_norm1, (lv331, gpt_neox_layers_10_input_layernorm_weight1, gpt_neox_layers_10_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv51 = R.call_tir(cls.dequantize1, (gpt_neox_layers_10_attention_query_key_value_q_weight1, gpt_neox_layers_10_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv332 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm20, lv51, gpt_neox_layers_10_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape40 = R.call_tir(cls.reshape4, (lv332,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape41 = R.call_tir(cls.reshape5, (reshape40,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv52 = 
R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape41), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape42 = R.call_tir(cls.reshape6, (lv52,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape43 = R.call_tir(cls.reshape7, (reshape42,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv53 = R.call_tir(cls.dequantize2, (gpt_neox_layers_10_attention_dense_q_weight1, gpt_neox_layers_10_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm21 = R.call_tir(cls.layer_norm1, (lv331, gpt_neox_layers_10_post_attention_layernorm_weight1, gpt_neox_layers_10_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv54 = R.call_tir(cls.dequantize3, (gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv333 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm21, lv54, gpt_neox_layers_10_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv55 = R.call_tir(cls.dequantize4, (gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv334 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv333, lv55, gpt_neox_layers_10_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv335 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape43, lv53, gpt_neox_layers_10_attention_dense_bias1, lv334, lv331), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm22 = R.call_tir(cls.layer_norm1, (lv335, gpt_neox_layers_11_input_layernorm_weight1, gpt_neox_layers_11_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv56 = R.call_tir(cls.dequantize1, 
(gpt_neox_layers_11_attention_query_key_value_q_weight1, gpt_neox_layers_11_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv336 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm22, lv56, gpt_neox_layers_11_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape44 = R.call_tir(cls.reshape4, (lv336,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape45 = R.call_tir(cls.reshape5, (reshape44,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv57 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape45), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape46 = R.call_tir(cls.reshape6, (lv57,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape47 = R.call_tir(cls.reshape7, (reshape46,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv58 = R.call_tir(cls.dequantize2, (gpt_neox_layers_11_attention_dense_q_weight1, gpt_neox_layers_11_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm23 = R.call_tir(cls.layer_norm1, (lv335, gpt_neox_layers_11_post_attention_layernorm_weight1, gpt_neox_layers_11_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv59 = R.call_tir(cls.dequantize3, (gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv337 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm23, lv59, gpt_neox_layers_11_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv60 = R.call_tir(cls.dequantize4, (gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv338 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv337, lv60, 
gpt_neox_layers_11_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv339 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape47, lv58, gpt_neox_layers_11_attention_dense_bias1, lv338, lv335), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm24 = R.call_tir(cls.layer_norm1, (lv339, gpt_neox_layers_12_input_layernorm_weight1, gpt_neox_layers_12_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv61 = R.call_tir(cls.dequantize1, (gpt_neox_layers_12_attention_query_key_value_q_weight1, gpt_neox_layers_12_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv340 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm24, lv61, gpt_neox_layers_12_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape48 = R.call_tir(cls.reshape4, (lv340,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape49 = R.call_tir(cls.reshape5, (reshape48,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv62 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape49), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape50 = R.call_tir(cls.reshape6, (lv62,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape51 = R.call_tir(cls.reshape7, (reshape50,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv63 = R.call_tir(cls.dequantize2, (gpt_neox_layers_12_attention_dense_q_weight1, gpt_neox_layers_12_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm25 = R.call_tir(cls.layer_norm1, (lv339, gpt_neox_layers_12_post_attention_layernorm_weight1, gpt_neox_layers_12_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv64 = R.call_tir(cls.dequantize3, (gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight1, 
gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv341 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm25, lv64, gpt_neox_layers_12_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv65 = R.call_tir(cls.dequantize4, (gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv342 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv341, lv65, gpt_neox_layers_12_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv343 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape51, lv63, gpt_neox_layers_12_attention_dense_bias1, lv342, lv339), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm26 = R.call_tir(cls.layer_norm1, (lv343, gpt_neox_layers_13_input_layernorm_weight1, gpt_neox_layers_13_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv66 = R.call_tir(cls.dequantize1, (gpt_neox_layers_13_attention_query_key_value_q_weight1, gpt_neox_layers_13_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv344 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm26, lv66, gpt_neox_layers_13_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape52 = R.call_tir(cls.reshape4, (lv344,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape53 = R.call_tir(cls.reshape5, (reshape52,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv67 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape53), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape54 = R.call_tir(cls.reshape6, (lv67,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape55 = R.call_tir(cls.reshape7, (reshape54,), out_sinfo=R.Tensor((1, 
seq_len, 2048), dtype="float16")) lv68 = R.call_tir(cls.dequantize2, (gpt_neox_layers_13_attention_dense_q_weight1, gpt_neox_layers_13_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm27 = R.call_tir(cls.layer_norm1, (lv343, gpt_neox_layers_13_post_attention_layernorm_weight1, gpt_neox_layers_13_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv69 = R.call_tir(cls.dequantize3, (gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv345 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm27, lv69, gpt_neox_layers_13_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv70 = R.call_tir(cls.dequantize4, (gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv346 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv345, lv70, gpt_neox_layers_13_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv347 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape55, lv68, gpt_neox_layers_13_attention_dense_bias1, lv346, lv343), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm28 = R.call_tir(cls.layer_norm1, (lv347, gpt_neox_layers_14_input_layernorm_weight1, gpt_neox_layers_14_input_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv71 = R.call_tir(cls.dequantize1, (gpt_neox_layers_14_attention_query_key_value_q_weight1, gpt_neox_layers_14_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv348 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm28, lv71, gpt_neox_layers_14_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape56 = R.call_tir(cls.reshape4, (lv348,), out_sinfo=R.Tensor((1, seq_len, 24, 256), 
dtype="float16")) reshape57 = R.call_tir(cls.reshape5, (reshape56,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv72 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape57), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape58 = R.call_tir(cls.reshape6, (lv72,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape59 = R.call_tir(cls.reshape7, (reshape58,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv73 = R.call_tir(cls.dequantize2, (gpt_neox_layers_14_attention_dense_q_weight1, gpt_neox_layers_14_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm29 = R.call_tir(cls.layer_norm1, (lv347, gpt_neox_layers_14_post_attention_layernorm_weight1, gpt_neox_layers_14_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv74 = R.call_tir(cls.dequantize3, (gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv349 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm29, lv74, gpt_neox_layers_14_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv75 = R.call_tir(cls.dequantize4, (gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) lv350 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv349, lv75, gpt_neox_layers_14_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv351 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape59, lv73, gpt_neox_layers_14_attention_dense_bias1, lv350, lv347), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm30 = R.call_tir(cls.layer_norm1, (lv351, gpt_neox_layers_15_input_layernorm_weight1, gpt_neox_layers_15_input_layernorm_bias1), 
out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv76 = R.call_tir(cls.dequantize1, (gpt_neox_layers_15_attention_query_key_value_q_weight1, gpt_neox_layers_15_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) lv352 = R.call_tir(cls.fused_NT_matmul5_add5, (layer_norm30, lv76, gpt_neox_layers_15_attention_query_key_value_bias1), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16")) reshape60 = R.call_tir(cls.reshape4, (lv352,), out_sinfo=R.Tensor((1, seq_len, 24, 256), dtype="float16")) reshape61 = R.call_tir(cls.reshape5, (reshape60,), out_sinfo=R.Tensor((seq_len, 24, 256), dtype="float16")) lv77 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape61), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape62 = R.call_tir(cls.reshape6, (lv77,), out_sinfo=R.Tensor((1, seq_len, 8, 256), dtype="float16")) reshape63 = R.call_tir(cls.reshape7, (reshape62,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv78 = R.call_tir(cls.dequantize2, (gpt_neox_layers_15_attention_dense_q_weight1, gpt_neox_layers_15_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) layer_norm31 = R.call_tir(cls.layer_norm1, (lv351, gpt_neox_layers_15_post_attention_layernorm_weight1, gpt_neox_layers_15_post_attention_layernorm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv79 = R.call_tir(cls.dequantize3, (gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) lv353 = R.call_tir(cls.fused_NT_matmul7_add7_gelu1_cast3, (layer_norm31, lv79, gpt_neox_layers_15_mlp_dense_h_to_4h_bias1), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16")) lv80 = R.call_tir(cls.dequantize4, (gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), 
dtype="float16")) lv354 = R.call_tir(cls.fused_NT_matmul8_add8_cast4, (lv353, lv80, gpt_neox_layers_15_mlp_dense_4h_to_h_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv355 = R.call_tir(cls.fused_NT_matmul6_add6_add9_add9, (reshape63, lv78, gpt_neox_layers_15_attention_dense_bias1, lv354, lv351), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) layer_norm32 = R.call_tir(cls.layer_norm1, (lv355, gpt_neox_final_layer_norm_weight1, gpt_neox_final_layer_norm_bias1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16")) lv81 = R.call_tir(cls.index, (layer_norm32,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16")) lv82 = R.call_tir(cls.dequantize, (embed_out_q_weight1, embed_out_q_scale1), out_sinfo=R.Tensor((vocab_size, 2048), dtype="float16")) lv356 = R.call_tir(cls.fused_NT_matmul14_cast8, (lv81, lv82), out_sinfo=R.Tensor((1, 1, vocab_size), dtype="float32")) gv1: R.Tuple(R.Tensor((1, 1, vocab_size), dtype="float32"), R.Object) = lv356, paged_kv_cache R.output(gv1) return gv1 @R.function def softmax_with_temperature(logits: R.Tensor(("batch_size", 1, "vocab_size"), dtype="float32"), temperature: R.Tensor(("batch_size",), dtype="float32")) -> R.Tensor(("batch_size", 1, "vocab_size"), dtype="float32"): batch_size = T.int64(is_size_var=True) vocab_size = T.int64(is_size_var=True) R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module with R.dataflow(): lv: R.Tensor((batch_size, vocab_size), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", logits, R.shape([batch_size, vocab_size]), sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float32"),)) lv1 = R.call_tir(cls.chunk_lse, (lv, temperature), out_sinfo=[R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32"), R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32")]) lv2: R.Tensor((batch_size, 
(vocab_size + 4096 - 1) // 4096), dtype="float32") = lv1[0] lv3: R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32") = lv1[1] lv4 = R.call_tir(cls.softmax_with_chunked_sum, (lv, temperature, lv2, lv3), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32")) gv: R.Tensor((batch_size, 1, vocab_size), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", lv4, R.shape([batch_size, 1, vocab_size]), sinfo_args=(R.Tensor((batch_size, 1, vocab_size), dtype="float32"),)) R.output(gv) return gv