# from tvm.script import ir as I # from tvm.script import tir as T # from tvm.script import relax as R @I.ir_module class Module: I.module_attrs({"system_lib_prefix": "gpt_neox_q4f16_1_"}) @T.prim_func def apply_bitmask_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_bitmask: T.handle): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)}) batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True) logits = T.match_buffer(var_logits, (batch_size, vocab_size)) num_seq = T.int32(is_size_var=True) seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32") # with T.block("root"): for fused_s_v_0 in T.thread_binding((num_seq * vocab_size + 255) // 256, thread="blockIdx.x"): for fused_s_v_1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("block"): vs = T.axis.spatial(num_seq, (fused_s_v_0 * 256 + fused_s_v_1) // vocab_size) vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 256 + fused_s_v_1) % vocab_size) T.where(fused_s_v_0 * 256 + fused_s_v_1 < num_seq * vocab_size) T.reads(bitmask[seq_ids[vs], vv // 32], seq_ids[vs], logits[seq_ids[vs], vv]) T.writes(logits[seq_ids[vs], vv]) logits[seq_ids[vs], vv] = T.if_then_else(T.bitwise_and(T.shift_right(bitmask[seq_ids[vs], vv // 32], vv % 32), 1) == 1, logits[seq_ids[vs], vv], T.float32(-340282346638528859811704183484516925440.0)) @T.prim_func def apply_logit_bias_inplace(var_logits: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_logit_bias: T.handle): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)}) batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True) logits = T.match_buffer(var_logits, (batch_size, vocab_size)) num_token = T.int32(is_size_var=True) pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32") token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") logit_bias = T.match_buffer(var_logit_bias, (num_token,)) # with T.block("root"): for p0 in T.thread_binding((num_token + 255) // 256, thread="blockIdx.x"): for p1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("block"): vp = T.axis.spatial(num_token, p0 * 256 + p1) T.where(p0 * 256 + p1 < num_token) T.reads(logits[pos2seq_id[vp], token_ids[vp]], pos2seq_id[vp], token_ids[vp], logit_bias[vp]) T.writes(logits[pos2seq_id[vp], token_ids[vp]]) logits[pos2seq_id[vp], token_ids[vp]] = logits[pos2seq_id[vp], token_ids[vp]] + logit_bias[vp] @T.prim_func def apply_penalty_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_token_cnt: T.handle, var_penalties: T.handle): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": 
["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)}) batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True) logits = T.match_buffer(var_logits, (batch_size, vocab_size)) num_seq = T.int32(is_size_var=True) seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32") num_token = T.int32(is_size_var=True) pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32") token_ids = T.match_buffer(var_token_ids, (num_token,), "int32") token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32") penalties = T.match_buffer(var_penalties, (num_seq, 3)) # with T.block("root"): for p0 in T.thread_binding((num_token + 255) // 256, thread="blockIdx.x"): for p1 in T.thread_binding(256, thread="threadIdx.x"): with T.block("block"): vp = T.axis.spatial(num_token, p0 * 256 + p1) T.where(p0 * 256 + p1 < num_token) T.reads(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]], seq_ids[pos2seq_id[vp]], pos2seq_id[vp], token_ids[vp], penalties[pos2seq_id[vp], 0:3], token_cnt[vp]) T.writes(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]]) logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] - (penalties[pos2seq_id[vp], 0] + T.Cast("float32", token_cnt[vp]) * penalties[pos2seq_id[vp], 1]) logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < T.float32(0.0), logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2], logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2]) @T.prim_func def batch_decode_paged_kv(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) B = T.int32(is_size_var=True) Q = T.match_buffer(Q_handle, (B, 8, 256), "float16") max_num_pages = T.int32(is_size_var=True) pages = T.match_buffer(pages_handle, (max_num_pages, 2, 8, 16, 256), "float16", offset_factor=1) page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1) nnz_pages = T.int32(is_size_var=True) page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1) length_info = T.match_buffer(var_length_info, (B,), "int32", offset_factor=1) k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1) q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1) output = T.match_buffer(output_handle, (B, 8, 256), "float16") lse = T.match_buffer(lse_handle, (B, 8)) # with T.block("root"): sm_scale: T.float32 = T.float32(0.090168440055560212) for bx in T.thread_binding(B, thread="blockIdx.x"): for fused_by_bz in 
T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(1, thread="threadIdx.y"): for tx in T.thread_binding(64, thread="threadIdx.x"): for tz in T.thread_binding(4, thread="threadIdx.z"): with T.block("attn"): T.reads(page_table_indptr[bx:bx + 2], length_info[bx], q_rope_position[bx], Q[bx, fused_by_bz // 8 + ty + fused_by_bz % 8, tx * 4 - 128:tx * 4 - 128 + 260]) T.writes(output[bx, fused_by_bz % 8 + fused_by_bz // 8 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 8 + fused_by_bz // 8 + ty]) Q_local = T.alloc_buffer((4,), "float16", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") K_smem = T.alloc_buffer((4, 256), "float16", scope="shared") V_smem = T.alloc_buffer((4, 256), "float16", scope="shared") O_allreduce = T.alloc_buffer((4, 1, 256), scope="shared") md_allreduce = T.alloc_buffer((4, 1, 2), scope="shared") S_reduce_local = T.alloc_buffer((1,), scope="local") t0 = T.alloc_buffer((1,), scope="local") S_local = T.alloc_buffer((1,), scope="local") QK_local = T.alloc_buffer((4,), scope="local") V_local = T.alloc_buffer((4,), "float16", scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_prev = T.alloc_buffer((1,), scope="local") other_m = T.alloc_buffer((1,), scope="local") other_d = T.alloc_buffer((1,), scope="local") exp_mprev = T.alloc_buffer((1,), scope="local") exp_otherm = T.alloc_buffer((1,), scope="local") other_o = T.alloc_buffer((4,), scope="local") st_m = T.alloc_buffer((1,), scope="local") st_d = T.alloc_buffer((1,), scope="local") O_local = T.alloc_buffer((4,), scope="local") by: T.int32 = fused_by_bz % 8 bz: T.int32 = fused_by_bz // 8 batch_idx: T.int32 = bx cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx] cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1] kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[batch_idx], 0) st_m[0] = T.float32(-50000.0) st_d[0] = T.float32(1.0) for vec in T.vectorized(4): O_local[vec] = T.float32(0.0) for vec in T.vectorized(4): freq = T.float32() Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by + bz + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 128, Q[bx, by + bz + ty, tx * 4 + vec + 128] * T.float16(-1.0), Q[bx, by + bz + ty, tx * 4 + vec - 128]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 256) / T.float32(256.0))}), Q[bx, by + bz + ty, tx * 4 + vec]) for iterator in range((kv_chunk_len[0] + 3) // 4): tile_start_s: T.int32 = tz + ty tile_start_g: T.int32 = iterator * 4 + tz + ty for j in range(1): with T.block("KV_load"): T.reads() T.writes() row_g: T.int32 = tile_start_g + j if row_g < kv_chunk_len[0]: seq_offset: T.int32 = row_g page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16] page_offset: T.int32 = seq_offset % 16 for vec in T.vectorized(4): freq = T.float32() K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 128, pages[page_no, 0, by, page_offset, tx * 4 + vec + 128] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 128]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, 
T.Cast("float32", (tx * 4 + vec) * 2 % 256) / T.float32(256.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec]) V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec] else: for vec in T.vectorized(4): K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0) V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0) T.tvm_storage_sync("shared") m_prev[0] = st_m[0] for j in range(1): for vec in T.vectorized(4): QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale S_reduce_local[0] = T.float32(0.0) for vec in T.unroll(4): S_reduce_local[0] = S_reduce_local[0] + QK_local[vec] with T.block("block_cross_thread"): T.reads(S_reduce_local[0]) T.writes(t0[0]) T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx) S_local[j] = T.float32(-50000.0) if iterator * 4 + tz + j < kv_chunk_len[0]: S_local[j] = t0[0] st_m[0] = T.max(st_m[0], S_local[j]) o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0]) st_d[0] = st_d[0] * o_scale for j in range(1): S_local[j] = T.exp2(S_local[j] - st_m[0]) st_d[0] = st_d[0] + S_local[j] for j in T.vectorized(4): O_local[j] = O_local[j] * o_scale for j in range(1): for vec in T.vectorized(4): V_local[vec] = V_smem[tz + j, tx * 4 + vec] for vec in T.vectorized(4): O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j] for vec in T.vectorized(4): O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec] md_allreduce[tz, ty, 0] = st_m[0] md_allreduce[tz, ty, 1] = st_d[0] T.tvm_storage_sync("shared") st_m[0] = T.float32(-50000.0) st_d[0] = T.float32(1.0) for vec in T.vectorized(4): O_local[vec] = T.float32(0.0) for j in range(4): m_prev[0] = st_m[0] d_prev[0] = st_d[0] other_m[0] = md_allreduce[j, ty, 0] other_d[0] = md_allreduce[j, ty, 1] for vec in T.vectorized(4): other_o[vec] = O_allreduce[j, ty, tx * 4 + vec] st_m[0] = T.max(st_m[0], other_m[0]) st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0]) exp_mprev[0] = T.exp2(m_prev[0] - st_m[0]) exp_otherm[0] = T.exp2(other_m[0] - st_m[0]) for vec in T.vectorized(4): O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0] for vec in T.vectorized(4): O_local[vec] = O_local[vec] / st_d[0] for vec in T.vectorized(4): output[batch_idx, by + bz + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec]) lse[batch_idx, by + bz + ty] = st_m[0] + T.log2(st_d[0]) @T.prim_func def batch_decode_paged_kv_sliding_window(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) B = T.int32(is_size_var=True) Q = T.match_buffer(Q_handle, (B, 8, 256), "float16") max_num_pages = T.int32(is_size_var=True) pages = 
T.match_buffer(pages_handle, (max_num_pages, 2, 8, 16, 256), "float16", offset_factor=1) page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1) nnz_pages = T.int32(is_size_var=True) page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1) length_info = T.match_buffer(var_length_info, (3, B), "int32", offset_factor=1) k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1) q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1) output = T.match_buffer(output_handle, (B, 8, 256), "float16") lse = T.match_buffer(lse_handle, (B, 8)) # with T.block("root"): sm_scale: T.float32 = T.float32(0.090168440055560212) for bx in T.thread_binding(B, thread="blockIdx.x"): for fused_by_bz in T.thread_binding(8, thread="blockIdx.y"): for ty in T.thread_binding(1, thread="threadIdx.y"): for tx in T.thread_binding(64, thread="threadIdx.x"): for tz in T.thread_binding(4, thread="threadIdx.z"): with T.block("attn"): T.reads(page_table_indptr[bx:bx + 2], length_info[0:3, bx], q_rope_position[bx], Q[bx, fused_by_bz // 8 + ty + fused_by_bz % 8, tx * 4 - 128:tx * 4 - 128 + 260]) T.writes(output[bx, fused_by_bz % 8 + fused_by_bz // 8 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 8 + fused_by_bz // 8 + ty]) Q_local = T.alloc_buffer((4,), "float16", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") K_smem = T.alloc_buffer((4, 256), "float16", scope="shared") V_smem = T.alloc_buffer((4, 256), "float16", scope="shared") O_allreduce = T.alloc_buffer((4, 1, 256), scope="shared") md_allreduce = T.alloc_buffer((4, 1, 2), scope="shared") S_reduce_local = T.alloc_buffer((1,), scope="local") t0 = T.alloc_buffer((1,), scope="local") S_local = T.alloc_buffer((1,), scope="local") QK_local = T.alloc_buffer((4,), scope="local") V_local = T.alloc_buffer((4,), "float16", scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_prev = T.alloc_buffer((1,), scope="local") other_m = T.alloc_buffer((1,), scope="local") other_d = T.alloc_buffer((1,), scope="local") exp_mprev = T.alloc_buffer((1,), scope="local") exp_otherm = T.alloc_buffer((1,), scope="local") other_o = T.alloc_buffer((4,), scope="local") st_m = T.alloc_buffer((1,), scope="local") st_d = T.alloc_buffer((1,), scope="local") O_local = T.alloc_buffer((4,), scope="local") by: T.int32 = fused_by_bz % 8 bz: T.int32 = fused_by_bz // 8 batch_idx: T.int32 = bx cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx] cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1] kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, batch_idx] - length_info[1, batch_idx] + length_info[2, batch_idx], 0) st_m[0] = T.float32(-50000.0) st_d[0] = T.float32(1.0) for vec in T.vectorized(4): O_local[vec] = T.float32(0.0) for vec in T.vectorized(4): freq = T.float32() Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by + bz + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 128, Q[bx, by + bz + ty, tx * 4 + vec + 128] * T.float16(-1.0), Q[bx, by + bz + ty, tx * 4 + vec - 128]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 256) / T.float32(256.0))}), Q[bx, by + bz + ty, tx * 4 + vec]) for iterator in range((kv_chunk_len[0] + 3) 
// 4): tile_start_s: T.int32 = tz + ty tile_start_g: T.int32 = iterator * 4 + tz + ty for j in range(1): with T.block("KV_load"): T.reads() T.writes() row_g: T.int32 = tile_start_g + j if row_g < kv_chunk_len[0]: seq_offset: T.int32 = T.if_then_else(row_g < length_info[2, batch_idx], row_g, row_g - length_info[2, batch_idx] + length_info[1, batch_idx]) page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16] page_offset: T.int32 = seq_offset % 16 for vec in T.vectorized(4): freq = T.float32() K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 128, pages[page_no, 0, by, page_offset, tx * 4 + vec + 128] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 128]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 256) / T.float32(256.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec]) V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec] else: for vec in T.vectorized(4): K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0) V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0) T.tvm_storage_sync("shared") m_prev[0] = st_m[0] for j in range(1): for vec in T.vectorized(4): QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale S_reduce_local[0] = T.float32(0.0) for vec in T.unroll(4): S_reduce_local[0] = S_reduce_local[0] + QK_local[vec] with T.block("block_cross_thread"): T.reads(S_reduce_local[0]) T.writes(t0[0]) T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0))) T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx) S_local[j] = T.float32(-50000.0) if iterator * 4 + tz + j < kv_chunk_len[0]: S_local[j] = t0[0] st_m[0] = T.max(st_m[0], S_local[j]) o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0]) st_d[0] = st_d[0] * o_scale for j in range(1): S_local[j] = T.exp2(S_local[j] - st_m[0]) st_d[0] = st_d[0] + S_local[j] for j in T.vectorized(4): O_local[j] = O_local[j] * o_scale for j in range(1): for vec in T.vectorized(4): V_local[vec] = V_smem[tz + j, tx * 4 + vec] for vec in T.vectorized(4): O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j] for vec in T.vectorized(4): O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec] md_allreduce[tz, ty, 0] = st_m[0] md_allreduce[tz, ty, 1] = st_d[0] T.tvm_storage_sync("shared") st_m[0] = T.float32(-50000.0) st_d[0] = T.float32(1.0) for vec in T.vectorized(4): O_local[vec] = T.float32(0.0) for j in range(4): m_prev[0] = st_m[0] d_prev[0] = st_d[0] other_m[0] = md_allreduce[j, ty, 0] other_d[0] = md_allreduce[j, ty, 1] for vec in T.vectorized(4): other_o[vec] = O_allreduce[j, ty, tx * 4 + vec] st_m[0] = T.max(st_m[0], other_m[0]) st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0]) exp_mprev[0] = T.exp2(m_prev[0] - st_m[0]) exp_otherm[0] = T.exp2(other_m[0] - st_m[0]) for vec in T.vectorized(4): O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0] for vec in T.vectorized(4): O_local[vec] = O_local[vec] / st_d[0] for vec in T.vectorized(4): output[batch_idx, by + bz + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec]) lse[batch_idx, by + bz + ty] 
= st_m[0] + T.log2(st_d[0]) @T.prim_func def batch_prefill_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) total_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (total_len, 8, 256), "float16") batch_size = T.int32(is_size_var=True) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) max_num_pages = T.int32(is_size_var=True) pages = T.match_buffer(var_pages, (max_num_pages, 2, 8, 16, 256), "float16", offset_factor=1) page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1) nnz_pages = T.int32(is_size_var=True) page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1) length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1) k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1) q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1) output = T.match_buffer(var_output, (total_len, 8, 256), "float16") lse = T.match_buffer(var_lse, (total_len, 8)) # with T.block("root"): for lbx in T.thread_binding(16, thread="blockIdx.x"): for lby in T.thread_binding(8, thread="blockIdx.y"): for lty in T.thread_binding(4, thread="threadIdx.y"): for ltx in T.thread_binding(32, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() tile_id = T.alloc_buffer((1,), "int32", scope="local") batch_idx = T.alloc_buffer((1,), "int32", scope="local") batch_tiles = T.alloc_buffer((1,), "int32", scope="local") batch_rows = T.alloc_buffer((1,), "int32", scope="local") iterator = T.alloc_buffer((1,), "int32", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") Q_smem = T.alloc_buffer((16, 256), "float16", scope="shared") K_smem = T.alloc_buffer((16, 256), "float16", scope="shared") V_smem = T.alloc_buffer((16, 256), "float16", scope="shared") S_smem = T.alloc_buffer((16, 16), scope="shared") S_local = T.alloc_buffer((16, 16), scope="local") O_local = T.alloc_buffer((16, 256), scope="local") m_smem = T.alloc_buffer((16,), scope="shared") m_prev_smem = T.alloc_buffer((16,), scope="shared") d_smem = T.alloc_buffer((16,), scope="shared") m_new = T.alloc_buffer((1,), scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_new = T.alloc_buffer((1,), scope="local") tile_id[0] = bx batch_idx[0] = 0 batch_rows[0] = q_indptr[1] - q_indptr[0] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 while T.tvm_thread_invariant(batch_idx[0] < batch_size): while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: tile_id[0] = tile_id[0] - batch_tiles[0] batch_idx[0] = batch_idx[0] + 1 if batch_idx[0] < batch_size: b_idx: T.int32 = batch_idx[0] batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] 
batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] LH_start: T.int32 = tile_id[0] * 16 q_indptr_val: T.int32 = q_indptr[b_idx] cur_page_indptr_begin: T.int32 = page_indptr[b_idx] cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0) T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: m_smem[row] = T.float32(-50000.0) d_smem[row] = T.float32(1.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads() T.writes(O_local[i, j]) O_local[i, j] = T.float32(0.0) T.tvm_storage_sync("shared") for li_lj_fused_0 in range(8): for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for li_lj_fused_3 in T.vectorized(4): with T.block("Q_load"): i = T.axis.spatial(16, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 256) j = T.axis.spatial(256, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = q_indptr_val + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: freq = T.float32() Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, q[cur_L, cur_H_qo, j + 128] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 128]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), q[cur_L, cur_H_qo, j]) else: Q_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") for iterator_1 in range((kv_chunk_len[0] + 15) // 16): L_kv_start: T.int32 = iterator_1 * 16 for lz_ly_fused_0 in range(8): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("K_load"): i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 256) j = T.axis.spatial(256, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32 = cur_L page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16] page_offset: T.int32 = seq_offset % 16 freq = T.float32() K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, pages[page_no, 0, by, page_offset, j + 128] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 128]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), pages[page_no, 0, by, page_offset, j]) else: K_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") for lz_ly_fused_0 in range(8): 
for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("V_load"): i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 256) j = T.axis.spatial(256, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32 = cur_L page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16] page_offset: T.int32 = seq_offset % 16 V_smem[i, j] = pages[page_no, 1, by, page_offset, j] else: V_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") with T.block(""): T.reads(Q_smem[0:16, 0:256], K_smem[0:16, 0:256]) T.writes(S_local[0:16, 0:16]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(1, 2): with T.block("S_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 + li_1_init) j = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 2 + lj_1_init) T.reads() T.writes(S_local[i, j]) S_local[i, j] = T.float32(0.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, li_1, lj_1, lk_1 in T.grid(32, 1, 2, 8): with T.block("S_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) k = T.axis.reduce(256, lk_0 * 8 + lk_1) T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k]) T.writes(S_local[i, j]) S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.090168440055560212) T.tvm_storage_sync("shared") for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(1, 2): with T.block("S_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) T.reads(S_local[i, j]) T.writes(S_smem[i, j]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update1"): T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:16], d_smem[row], m_prev[i]) T.writes(m_prev[i], m_new[i], d_new[i]) m_prev[i] = m_smem[row] m_new[i] = m_smem[row] row_: T.int32 = LH_start + row for j in range(16): if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx with T.block("update"): T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:16], m_new[i]) T.writes(S_smem[row, 0:16]) for j in range(16): if row < 16: row_: T.int32 = LH_start + row if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: 
S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update"): T.reads(d_new[i], S_smem[row, 0:16], m_new[i], m_prev[i]) T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) for j in range(16): d_new[i] = d_new[i] + S_smem[row, j] m_smem[row] = m_new[i] d_smem[row] = d_new[i] m_prev_smem[row] = m_prev[i] T.tvm_storage_sync("shared") with T.block(""): T.reads(m_prev_smem[0:16], m_smem[0:16], S_smem[0:16, 0:16], V_smem[0:16, 0:256]) T.writes(O_local[0:16, 0:256]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(4, 8): with T.block("O_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 32 * 4 + li_1_init) j = T.axis.spatial(256, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 32 * 8 + lj_1_init) T.reads() T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, lk_1, li_1, lj_1 in T.grid(2, 8, 4, 8): with T.block("O_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) k = T.axis.reduce(16, lk_0 * 8 + lk_1) T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j]) T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) for li_0 in range(1): for li_1 in T.thread_binding(4, thread="threadIdx.y"): for li_2 in T.thread_binding(32, thread="threadIdx.x"): with T.block("lse_store"): i = T.axis.spatial(16, li_0 * 128 + li_1 * 32 + li_2) T.where((li_0 * 4 + li_1) * 32 + li_2 < 16) T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) tile_id[0] = tile_id[0] + 16 @T.prim_func def batch_prefill_paged_kv_sliding_window(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 
128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) total_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (total_len, 8, 256), "float16") batch_size = T.int32(is_size_var=True) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) max_num_pages = T.int32(is_size_var=True) pages = T.match_buffer(var_pages, (max_num_pages, 2, 8, 16, 256), "float16", offset_factor=1) page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1) nnz_pages = T.int32(is_size_var=True) page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1) length_info = T.match_buffer(var_length_info, (3, batch_size), "int32", offset_factor=1) k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1) q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1) output = T.match_buffer(var_output, (total_len, 8, 256), "float16") lse = T.match_buffer(var_lse, (total_len, 8)) # with T.block("root"): for lbx in T.thread_binding(16, thread="blockIdx.x"): for lby in T.thread_binding(8, thread="blockIdx.y"): for lty in T.thread_binding(4, thread="threadIdx.y"): for ltx in T.thread_binding(32, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() tile_id = T.alloc_buffer((1,), "int32", scope="local") batch_idx = T.alloc_buffer((1,), "int32", scope="local") batch_tiles = T.alloc_buffer((1,), "int32", scope="local") batch_rows = T.alloc_buffer((1,), "int32", scope="local") iterator = T.alloc_buffer((1,), "int32", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") Q_smem = T.alloc_buffer((16, 256), "float16", scope="shared") K_smem = T.alloc_buffer((16, 256), "float16", scope="shared") V_smem = T.alloc_buffer((16, 256), "float16", scope="shared") S_smem = T.alloc_buffer((16, 16), scope="shared") S_local = T.alloc_buffer((16, 16), scope="local") O_local = T.alloc_buffer((16, 256), scope="local") m_smem = T.alloc_buffer((16,), scope="shared") m_prev_smem = T.alloc_buffer((16,), scope="shared") d_smem = T.alloc_buffer((16,), scope="shared") m_new = T.alloc_buffer((1,), scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_new = T.alloc_buffer((1,), scope="local") tile_id[0] = bx batch_idx[0] = 0 batch_rows[0] = q_indptr[1] - q_indptr[0] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 while T.tvm_thread_invariant(batch_idx[0] < batch_size): while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: tile_id[0] = tile_id[0] - batch_tiles[0] batch_idx[0] = batch_idx[0] + 1 if batch_idx[0] < batch_size: b_idx: T.int32 = batch_idx[0] batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] LH_start: T.int32 = tile_id[0] * 16 q_indptr_val: T.int32 = q_indptr[b_idx] cur_page_indptr_begin: T.int32 = page_indptr[b_idx] cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, b_idx] - length_info[1, b_idx] + length_info[2, b_idx], 0) T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: m_smem[row] = T.float32(-50000.0) 
d_smem[row] = T.float32(1.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads() T.writes(O_local[i, j]) O_local[i, j] = T.float32(0.0) T.tvm_storage_sync("shared") for li_lj_fused_0 in range(8): for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for li_lj_fused_3 in T.vectorized(4): with T.block("Q_load"): i = T.axis.spatial(16, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 256) j = T.axis.spatial(256, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = q_indptr_val + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: freq = T.float32() Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, q[cur_L, cur_H_qo, j + 128] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 128]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), q[cur_L, cur_H_qo, j]) else: Q_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") for iterator_1 in range((kv_chunk_len[0] + 15) // 16): L_kv_start: T.int32 = iterator_1 * 16 for lz_ly_fused_0 in range(8): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("K_load"): i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 256) j = T.axis.spatial(256, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx]) page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16] page_offset: T.int32 = seq_offset % 16 freq = T.float32() K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, pages[page_no, 0, by, page_offset, j + 128] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 128]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), pages[page_no, 0, by, page_offset, j]) else: K_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") for lz_ly_fused_0 in range(8): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("V_load"): i = T.axis.spatial(16, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 256) j = T.axis.spatial(256, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: 
T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx]) page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16] page_offset: T.int32 = seq_offset % 16 V_smem[i, j] = pages[page_no, 1, by, page_offset, j] else: V_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") with T.block(""): T.reads(Q_smem[0:16, 0:256], K_smem[0:16, 0:256]) T.writes(S_local[0:16, 0:16]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(1, 2): with T.block("S_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 + li_1_init) j = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 2 + lj_1_init) T.reads() T.writes(S_local[i, j]) S_local[i, j] = T.float32(0.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, li_1, lj_1, lk_1 in T.grid(32, 1, 2, 8): with T.block("S_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) k = T.axis.reduce(256, lk_0 * 8 + lk_1) T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k]) T.writes(S_local[i, j]) S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.090168440055560212) T.tvm_storage_sync("shared") for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(1, 2): with T.block("S_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1) T.reads(S_local[i, j]) T.writes(S_smem[i, j]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update1"): T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:16], d_smem[row], m_prev[i]) T.writes(m_prev[i], m_new[i], d_new[i]) m_prev[i] = m_smem[row] m_new[i] = m_smem[row] row_: T.int32 = LH_start + row for j in range(16): if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx with T.block("update"): T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:16], m_new[i]) T.writes(S_smem[row, 0:16]) for j in range(16): if row < 16: row_: T.int32 = LH_start + row if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update"): T.reads(d_new[i], S_smem[row, 0:16], m_new[i], m_prev[i]) T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) for j in range(16): d_new[i] = d_new[i] + S_smem[row, j] m_smem[row] = m_new[i] d_smem[row] = d_new[i] m_prev_smem[row] = m_prev[i] 
T.tvm_storage_sync("shared") with T.block(""): T.reads(m_prev_smem[0:16], m_smem[0:16], S_smem[0:16, 0:16], V_smem[0:16, 0:256]) T.writes(O_local[0:16, 0:256]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(4, 8): with T.block("O_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 32 * 4 + li_1_init) j = T.axis.spatial(256, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 32 * 8 + lj_1_init) T.reads() T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, lk_1, li_1, lj_1 in T.grid(2, 8, 4, 8): with T.block("O_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) k = T.axis.reduce(16, lk_0 * 8 + lk_1) T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j]) T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) for li_0 in range(1): for li_1 in T.thread_binding(4, thread="threadIdx.y"): for li_2 in T.thread_binding(32, thread="threadIdx.x"): with T.block("lse_store"): i = T.axis.spatial(16, li_0 * 128 + li_1 * 32 + li_2) T.where((li_0 * 4 + li_1) * 32 + li_2 < 16) T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) tile_id[0] = tile_id[0] + 16 @T.prim_func def batch_prefill_ragged_kv(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_k_rope_pos_offset: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) qo_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (qo_len, 8, 256), "float16") batch_size = T.int32(is_size_var=True) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) kv_len = T.int32(is_size_var=True) k = T.match_buffer(var_k, (kv_len, 8, 
256), "float16") v = T.match_buffer(var_v, (kv_len, 8, 256), "float16") kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1) q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1) k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1) output = T.match_buffer(var_output, (qo_len, 8, 256), "float16") lse = T.match_buffer(var_lse, (qo_len, 8)) # with T.block("root"): for lbx in T.thread_binding(8, thread="blockIdx.x"): for lby in T.thread_binding(8, thread="blockIdx.y"): for lty in T.thread_binding(4, thread="threadIdx.y"): for ltx in T.thread_binding(32, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() tile_id = T.alloc_buffer((1,), "int32", scope="local") batch_idx = T.alloc_buffer((1,), "int32", scope="local") batch_tiles = T.alloc_buffer((1,), "int32", scope="local") batch_rows = T.alloc_buffer((1,), "int32", scope="local") iterator = T.alloc_buffer((1,), "int32", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") Q_smem = T.alloc_buffer((16, 256), "float16", scope="shared") K_smem = T.alloc_buffer((256, 16), "float16", scope="shared") V_smem = T.alloc_buffer((16, 256), "float16", scope="shared") S_smem = T.alloc_buffer((16, 16), scope="shared") S_local = T.alloc_buffer((16, 16), scope="local") O_local = T.alloc_buffer((16, 256), scope="local") m_smem = T.alloc_buffer((16,), scope="shared") m_prev_smem = T.alloc_buffer((16,), scope="shared") d_smem = T.alloc_buffer((16,), scope="shared") m_new = T.alloc_buffer((1,), scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_new = T.alloc_buffer((1,), scope="local") tile_id[0] = bx batch_idx[0] = 0 batch_rows[0] = q_indptr[1] - q_indptr[0] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 while T.tvm_thread_invariant(batch_idx[0] < batch_size): while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: tile_id[0] = tile_id[0] - batch_tiles[0] batch_idx[0] = batch_idx[0] + 1 if batch_idx[0] < batch_size: b_idx: T.int32 = batch_idx[0] batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] q_indptr_val: T.int32 = q_indptr[b_idx] LH_start: T.int32 = tile_id[0] * 16 kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: m_smem[row] = T.float32(-50000.0) d_smem[row] = T.float32(1.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1 in range(4): for lj_1_0 in T.unroll(1): for lj_1_1 in T.vectorized(8): with T.block("O_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1_0 * 8 + lj_1_1) T.reads() T.writes(O_local[i, j]) O_local[i, j] = T.float32(0.0) T.tvm_storage_sync("shared") for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_0, lj_0_0 in T.grid(2, 2): for lj_1 in T.vectorized(8): with T.block("Q_load"): i = T.axis.spatial(16, li_0 * 8 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 16) j = T.axis.spatial(256, lj_0_0 * 128 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 
16 * 8 + lj_1) T.reads() T.writes() cur_L: T.int32 = q_indptr_val + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: freq = T.float32() Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, q[cur_L, cur_H_qo, j + 128] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 128]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), q[cur_L, cur_H_qo, j]) else: Q_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") for iterator_1 in range((kv_chunk_len[0] + 15) // 16): L_kv_start: T.int32 = iterator_1 * 16 L_kv_base: T.int32 = kv_indptr[b_idx] for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lz_0, ly_0_0 in T.grid(2, 2): for ly_1 in T.vectorized(8): with T.block("K_load"): i = T.axis.spatial(16, lz_0 * 8 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 16) j = T.axis.spatial(256, ly_0_0 * 128 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 16 * 8 + ly_1) T.reads(kv_chunk_len[0], k_rope_pos_offset[b_idx], k[L_kv_base + L_kv_start + i, by, j - 128:j - 128 + 257]) T.writes(K_smem[j, i]) cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: freq = T.float32() K_smem[j, i] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[L_kv_base + cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, k[L_kv_base + cur_L, by, j + 128] * T.float16(-1.0), k[L_kv_base + cur_L, by, j - 128]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), k[L_kv_base + cur_L, by, j]) else: K_smem[j, i] = T.float16(0.0) T.tvm_storage_sync("shared") for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lz_0, ly_0_0 in T.grid(2, 2): for ly_1 in T.vectorized(8): with T.block("V_load"): i = T.axis.spatial(16, lz_0 * 8 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 16) j = T.axis.spatial(256, ly_0_0 * 128 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 16 * 8 + ly_1) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: V_smem[i, j] = v[L_kv_base + cur_L, by, j] else: V_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") with T.block(""): T.reads(Q_smem[0:16, 0:256], K_smem[0:256, 0:16]) T.writes(S_local[0:16, 0:16]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init in T.unroll(1): for lj_1_0_init in T.unroll(1): for lj_1_1_init in T.vectorized(2): with T.block("S_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 + li_1_init) j = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 2 + lj_1_0_init * 2 + lj_1_1_init) T.reads() T.writes(S_local[i, j]) S_local[i, j] = T.float32(0.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0 in range(16): for li_1 in T.unroll(1): for lj_1_0 in T.unroll(1): for lj_1_1 in T.vectorized(2): for lk_1 in range(16): with T.block("S_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + 
li_0_lj_0_fused_1) // 8 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1_0 * 2 + lj_1_1) k_1 = T.axis.reduce(256, lk_0 * 16 + lk_1) T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[k_1, j]) T.writes(S_local[i, j]) S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[k_1, j]) * attn_score_scaling_factor * T.float32(0.090168440055560212) T.tvm_storage_sync("shared") for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1 in range(1): for lj_1_0 in T.unroll(1): for lj_1_1 in T.vectorized(2): with T.block("S_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 + li_1) j = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 2 + lj_1_0 * 2 + lj_1_1) T.reads(S_local[i, j]) T.writes(S_smem[i, j]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update1"): T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:16], d_smem[row], m_prev[i]) T.writes(m_prev[i], m_new[i], d_new[i]) m_prev[i] = m_smem[row] m_new[i] = m_smem[row] row_: T.int32 = LH_start + row for j in range(16): if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx with T.block("update"): T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:16], m_new[i]) T.writes(S_smem[row, 0:16]) for j in range(16): if row < 16: row_: T.int32 = LH_start + row if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update"): T.reads(d_new[i], S_smem[row, 0:16], m_new[i], m_prev[i]) T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) for j in range(16): d_new[i] = d_new[i] + S_smem[row, j] m_smem[row] = m_new[i] d_smem[row] = d_new[i] m_prev_smem[row] = m_prev[i] T.tvm_storage_sync("shared") with T.block(""): T.reads(m_prev_smem[0:16], m_smem[0:16], S_smem[0:16, 0:16], V_smem[0:16, 0:256]) T.writes(O_local[0:16, 0:256]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init in T.unroll(4): for lj_1_0_init in T.unroll(1): for lj_1_1_init in T.vectorized(8): with T.block("O_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 32 * 4 + li_1_init) j = T.axis.spatial(256, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 32 * 8 + lj_1_0_init * 8 + lj_1_1_init) T.reads() T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, lk_1 in T.grid(1, 16): for li_1 in T.unroll(4): for lj_1_0 in T.unroll(1): for lj_1_1 in T.vectorized(8): with T.block("O_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + 
li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1_0 * 8 + lj_1_1) k_1 = T.axis.reduce(16, lk_0 * 16 + lk_1) T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j]) T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1 in range(4): for lj_1_0 in T.unroll(1): for lj_1_1 in T.vectorized(8): with T.block("O_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1_0 * 8 + lj_1_1) T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) for li_0 in range(1): for li_1 in T.thread_binding(4, thread="threadIdx.y"): for li_2 in T.thread_binding(32, thread="threadIdx.x"): with T.block("lse_store"): i = T.axis.spatial(16, li_0 * 128 + li_1 * 32 + li_2) T.where((li_0 * 4 + li_1) * 32 + li_2 < 16) T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) tile_id[0] = tile_id[0] + 8 @T.prim_func def batch_tree_attn(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_mn_indptr: T.handle, var_mask: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, batch_size: T.int32): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) qo_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (qo_len, 8, 256), "float16") q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) kv_len = T.int32(is_size_var=True) k = T.match_buffer(var_k, (kv_len, 8, 256), "float16") v = T.match_buffer(var_v, (kv_len, 8, 256), "float16") kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1) q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1) mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", offset_factor=1) tree_size = T.int32(is_size_var=True) mask = T.match_buffer(var_mask, (tree_size, 2), "int32", offset_factor=1) output = T.match_buffer(var_output, (qo_len, 8, 256), "float16") lse = T.match_buffer(var_lse, (qo_len, 8)) # with T.block("root"): for lbx in T.thread_binding(16, thread="blockIdx.x"): for lby in T.thread_binding(8, thread="blockIdx.y"): for lty in T.thread_binding(4, thread="threadIdx.y"): for ltx in T.thread_binding(32, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() 
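# Note: "lse_store" above writes lse = m + log2(d), i.e. the base-2
# log-sum-exp of the scaled scores; merge_state_inplace later consumes it
# when combining partial attention results.  Equivalent NumPy for one row:
import numpy as np

def lse2(s):
    m = s.max()
    return m + np.log2(np.exp2(s - m).sum())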
T.writes() tile_id = T.alloc_buffer((1,), "int32", scope="local") batch_idx = T.alloc_buffer((1,), "int32", scope="local") batch_tiles = T.alloc_buffer((1,), "int32", scope="local") batch_rows = T.alloc_buffer((1,), "int32", scope="local") iterator = T.alloc_buffer((1,), "int32", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") Q_smem = T.alloc_buffer((16, 256), "float16", scope="shared") K_smem = T.alloc_buffer((32, 256), "float16", scope="shared") V_smem = T.alloc_buffer((32, 256), "float16", scope="shared") S_smem = T.alloc_buffer((16, 32), scope="shared") S_local = T.alloc_buffer((16, 32), scope="local") O_local = T.alloc_buffer((16, 256), scope="local") m_smem = T.alloc_buffer((16,), scope="shared") m_prev_smem = T.alloc_buffer((16,), scope="shared") d_smem = T.alloc_buffer((16,), scope="shared") m_new = T.alloc_buffer((1,), scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_new = T.alloc_buffer((1,), scope="local") tile_id[0] = bx batch_idx[0] = 0 batch_rows[0] = q_indptr[1] - q_indptr[0] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 while T.tvm_thread_invariant(batch_idx[0] < batch_size): while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: tile_id[0] = tile_id[0] - batch_tiles[0] batch_idx[0] = batch_idx[0] + 1 if batch_idx[0] < batch_size: b_idx: T.int32 = batch_idx[0] batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] LH_start: T.int32 = tile_id[0] * 16 q_indptr_val: T.int32 = q_indptr[b_idx] kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: m_smem[row] = T.float32(-50000.0) d_smem[row] = T.float32(1.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads() T.writes(O_local[i, j]) O_local[i, j] = T.float32(0.0) T.tvm_storage_sync("shared") for li_lj_fused_0 in range(8): for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for li_lj_fused_3 in T.vectorized(4): with T.block("Q_load"): i = T.axis.spatial(16, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 256) j = T.axis.spatial(256, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = q_indptr_val + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: freq = T.float32() Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, q[cur_L, cur_H_qo, j + 128] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 128]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), q[cur_L, cur_H_qo, j]) else: Q_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") for iterator_1 in range((kv_chunk_len[0] + 31) // 32): L_kv_start: T.int32 = iterator_1 * 32 L_kv_base: T.int32 = kv_indptr[b_idx] for lz_ly_fused_0 in range(16): for 
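# Note: the tile_id / batch_idx while-loop above is a persistent-block
# work schedule: grid block bx (out of the 16 blockIdx.x blocks here)
# handles every 16th 16-row tile of the ragged batch, skipping ahead by
# whole sequences via batch_tiles.  A plain-Python model of the same
# assignment (helper name illustrative):
def tiles_for_block(bx, num_blocks, rows_per_seq, tile_rows=16):
    tiles = [(b, t) for b, rows in enumerate(rows_per_seq)
             for t in range((rows + tile_rows - 1) // tile_rows)]
    return tiles[bx::num_blocks]  # the order the kernel visits them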
lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("KV_load"): i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 256) j = T.axis.spatial(256, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = L_kv_base + L_kv_start + i if L_kv_start + i < kv_chunk_len[0]: freq = T.float32() K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 128, k[cur_L, by, j + 128] * T.float16(-1.0), k[cur_L, by, j - 128]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), k[cur_L, by, j]) V_smem[i, j] = v[cur_L, by, j] else: K_smem[i, j] = T.float16(0.0) V_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") with T.block(""): T.reads(Q_smem[0:16, 0:256], K_smem[0:32, 0:256]) T.writes(S_local[0:16, 0:32]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(2, 2): with T.block("S_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 2 + li_1_init) j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 2 + lj_1_init) T.reads() T.writes(S_local[i, j]) S_local[i, j] = T.float32(0.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, li_1, lj_1, lk_1 in T.grid(32, 2, 2, 8): with T.block("S_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 2 + li_1) j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 2 + lj_1) k_1 = T.axis.reduce(256, lk_0 * 8 + lk_1) T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[j, k_1]) T.writes(S_local[i, j]) S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[j, k_1]) * attn_score_scaling_factor * T.float32(0.090168440055560212) T.tvm_storage_sync("shared") for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(2, 2): with T.block("S_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 2 + li_1) j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 2 + lj_1) T.reads(S_local[i, j]) T.writes(S_smem[i, j]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update1"): T.reads(m_smem[row], kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 
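# Note: the T.Let expression in Q_load/KV_load is rotary position
# embedding in the rotate-half layout (head_dim 256, halves of 128).  A
# rough NumPy reference for one head vector, assuming x.shape == (256,):
import numpy as np

def rope_rotate_half(x, pos, rope_scale=1.0, rope_theta=10000.0):
    d = x.shape[0]
    j = np.arange(d)
    freq = pos * rope_scale / rope_theta ** ((j * 2 % d) / d)
    rotated = np.concatenate([-x[d // 2:], x[:d // 2]])
    return np.cos(freq) * x + np.sin(freq) * rotated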
0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i]) T.writes(m_prev[i], m_new[i], d_new[i]) m_prev[i] = m_smem[row] m_new[i] = m_smem[row] row_: T.int32 = LH_start + row for j in range(32): if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx with T.block("update"): T.reads(kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i]) T.writes(S_smem[row, 0:32]) for j in range(32): if row < 16: row_: T.int32 = LH_start + row if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update"): T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i]) T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) for j in range(32): d_new[i] = d_new[i] + S_smem[row, j] m_smem[row] = m_new[i] d_smem[row] = d_new[i] m_prev_smem[row] = m_prev[i] T.tvm_storage_sync("shared") with T.block(""): T.reads(m_prev_smem[0:16], m_smem[0:16], S_smem[0:16, 0:32], V_smem[0:32, 0:256]) T.writes(O_local[0:16, 0:256]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(4, 8): with T.block("O_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 32 * 4 + li_1_init) j = T.axis.spatial(256, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 32 * 8 + lj_1_init) T.reads() T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) for li_0_lj_0_fused_0 in 
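# Note: reading the indices in the mask test above, query node q appears
# to be allowed to attend to tree node j iff mask[j, 0] <= mask[q, 0] <
# mask[j, 1], while positions in the committed prefix (before the tree)
# are always visible.  A boolean-mask sketch under that reading (helper
# name illustrative):
import numpy as np

def tree_visible(mask, prefix_len):
    order = mask[:, 0]
    in_interval = (mask[None, :, 0] <= order[:, None]) & \
                  (order[:, None] < mask[None, :, 1])
    prefix = np.ones((mask.shape[0], prefix_len), dtype=bool)
    return np.concatenate([prefix, in_interval], axis=1)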
T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8): with T.block("O_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) k_1 = T.axis.reduce(32, lk_0 * 8 + lk_1) T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j]) T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) for li_0 in range(1): for li_1 in T.thread_binding(4, thread="threadIdx.y"): for li_2 in T.thread_binding(32, thread="threadIdx.x"): with T.block("lse_store"): i = T.axis.spatial(16, li_0 * 128 + li_1 * 32 + li_2) T.where((li_0 * 4 + li_1) * 32 + li_2 < 16) T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) tile_id[0] = tile_id[0] + 16 @T.prim_func def chunk_lse(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True) A = T.match_buffer(var_A, (batch_size, vocab_size)) temperature = T.match_buffer(var_temperature, (batch_size,)) num_chunks = T.int64(is_size_var=True) chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks)) chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks)) # with T.block("root"): A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(4096))) temp_max = T.alloc_buffer((batch_size, num_chunks)) temp_sum = T.alloc_buffer((batch_size, num_chunks)) for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)): with T.block("pad"): v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2]) T.writes(A_pad[v0, v1, v2]) A_pad[v0, v1, v2] = T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-340282346638528859811704183484516925440.0)) for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)): with T.block("max"): v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) T.reads(A_pad[v0, v1, 
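# Note: chunk_lse pads the vocab axis to a multiple of 4096 with -FLT_MAX
# and divides by temperature only above the ~1e-5 threshold; at
# (near-)zero temperature the "sum" instead counts entries equal to the
# chunk max, which later turns the softmax into a uniform distribution
# over the argmax set.  Rough NumPy model of one chunk's statistics:
import numpy as np

def chunk_stats(a_chunk, temperature, eps=1e-5):
    x = a_chunk / temperature if temperature > eps else a_chunk
    m = x.max()
    if temperature > eps:
        return np.log(np.exp(x - m).sum()), m   # (chunked_sum, chunked_max)
    return float((x == m).sum()), m             # greedy path: match count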
v2]) T.writes(temp_max[v0, v1]) with T.init(): temp_max[v0, v1] = T.float32(-340282346638528859811704183484516925440.0) temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2]) for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)): with T.block("sum_exp"): v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2]) T.reads(temperature[v0], A_pad[v0, v1, v2], temp_max[v0, v1]) T.writes(temp_sum[v0, v1]) with T.init(): temp_sum[v0, v1] = T.float32(0.0) temp_sum[v0, v1] = temp_sum[v0, v1] + T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]), T.Cast("float32", A_pad[v0, v1, v2] == temp_max[v0, v1])), T.float32(0.0)) for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)): with T.block("log"): v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2]) T.reads(temperature[v0], temp_sum[v0, v1], temp_max[v0, v1]) T.writes(chunked_sum[v0, v1], chunked_max[v0, v1]) chunked_sum[v0, v1] = T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.log(temp_sum[v0, v1]), temp_sum[v0, v1]) chunked_max[v0, v1] = temp_max[v0, v1] @T.prim_func def compact_kv_copy(var_pages: T.handle, var_copy_length_indptr: T.handle, var_copy_src_dst_pos: T.handle, batch_size: T.int32): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) num_pages = T.int32() pages = T.match_buffer(var_pages, (num_pages, 2, 8, 16, 256), "float16", offset_factor=1) copy_length_indptr = T.match_buffer(var_copy_length_indptr, (batch_size + 1,), "int32", offset_factor=1) total_copy_length = T.int32() copy_src_dst_pos = T.match_buffer(var_copy_src_dst_pos, (2, total_copy_length), "int32", offset_factor=1) with T.block("root"): T.reads() T.writes() for bhd_o in T.thread_binding(batch_size * 8, thread="blockIdx.x"): for bhd_i in T.thread_binding(256, thread="threadIdx.x"): b: T.int32 = (bhd_o * 256 + bhd_i) // 2048 h: T.int32 = (bhd_o * 256 + bhd_i) // 256 % 8 d: T.int32 = (bhd_o * 256 + bhd_i) % 256 if bhd_o * 256 + bhd_i < batch_size * 8 * 256: for i in range(copy_length_indptr[b + 1] - copy_length_indptr[b]): src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i] dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i] pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[src_pos // 16, 0, h, src_pos % 16, d] pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[src_pos // 16, 1, h, src_pos % 16, d] @T.prim_func def copy_single_page(var_pages: T.handle, src_page_id: T.int64, tgt_page_id: T.int64, copy_length: T.int64): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) num_pages, page_size = T.int32(), T.int64() pages = T.match_buffer(var_pages, (num_pages, 2, 8, page_size, 256), "float16", offset_factor=1) # with T.block("root"): for b in T.thread_binding(copy_length * T.int64(8), thread="blockIdx.x"): for t in T.thread_binding(256, thread="threadIdx.x"): 
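# Note: compact_kv_copy (and copy_single_page) address the paged cache as
# pages[page_id, k_or_v, head, slot, dim] with 16 slots per page in these
# functions, so a flat token position p lives at (p // 16, p % 16).
# Sketch of the per-(head, dim) copy:
def copy_kv_entry(pages, src_pos, dst_pos, h, d, page_size=16):
    for kv in (0, 1):  # 0 = keys, 1 = values
        pages[dst_pos // page_size][kv][h][dst_pos % page_size][d] = \
            pages[src_pos // page_size][kv][h][src_pos % page_size][d]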
with T.block("copy"): vh = T.axis.spatial(8, T.Cast("int32", (b * T.int64(256) + T.Cast("int64", t)) // (copy_length * T.int64(256)))) vp = T.axis.spatial(copy_length, (b * T.int64(256) + T.Cast("int64", t)) % (copy_length * T.int64(256)) // T.int64(256)) vd = T.axis.spatial(256, T.Cast("int32", (b * T.int64(256) + T.Cast("int64", t)) % T.int64(256))) T.where(b * T.int64(256) + T.Cast("int64", t) < copy_length * T.int64(8) * T.int64(256)) T.reads(pages[src_page_id, 0:2, vh, vp, vd]) T.writes(pages[tgt_page_id, 0:2, vh, vp, vd]) pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd] pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd] @T.prim_func(private=True) def dequantize(var_gpt_neox_embed_in_q_weight: T.handle, var_gpt_neox_embed_in_q_scale: T.handle, var_dequantize: T.handle): T.func_attr({"target": T.target({"keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) vocab_size = T.int64() gpt_neox_embed_in_q_weight = T.match_buffer(var_gpt_neox_embed_in_q_weight, (vocab_size, T.int64(256)), "uint32") gpt_neox_embed_in_q_scale = T.match_buffer(var_gpt_neox_embed_in_q_scale, (vocab_size, T.int64(64)), "float16") dequantize = T.match_buffer(var_dequantize, (vocab_size, T.int64(2048)), "float16") # with T.block("root"): compute = T.alloc_buffer((vocab_size, T.int64(2048)), "float16") for i0, i1 in T.grid(vocab_size, T.int64(2048)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(gpt_neox_embed_in_q_weight[v_i0, v_i1 // T.int64(8)]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(gpt_neox_embed_in_q_weight[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15))) for i0, i1 in T.grid(vocab_size, T.int64(2048)): with T.block("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(compute[v_i0, v_i1], gpt_neox_embed_in_q_scale[v_i0, v_i1 // T.int64(32)]) T.writes(dequantize[v_i0, v_i1]) dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * gpt_neox_embed_in_q_scale[v_i0, v_i1 // T.int64(32)] @T.prim_func(private=True) def dequantize1(gpt_neox_layers_0_attention_query_key_value_q_weight1: T.Buffer((T.int64(6144), T.int64(256)), "uint32"), gpt_neox_layers_0_attention_query_key_value_q_scale1: T.Buffer((T.int64(6144), T.int64(64)), "float16"), dequantize: T.Buffer((T.int64(6144), T.int64(2048)), "float16")): T.func_attr({"target": T.target({"keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) # with T.block("root"): compute = T.alloc_buffer((T.int64(6144), T.int64(2048)), "float16") for i0, i1 in T.grid(T.int64(6144), T.int64(2048)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(gpt_neox_layers_0_attention_query_key_value_q_weight1[v_i0, v_i1 // T.int64(8)]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(gpt_neox_layers_0_attention_query_key_value_q_weight1[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15))) for i0, i1 in T.grid(T.int64(6144), T.int64(2048)): with T.block("dequantize"): v_i0, v_i1 = 
T.axis.remap("SS", [i0, i1]) T.reads(compute[v_i0, v_i1], gpt_neox_layers_0_attention_query_key_value_q_scale1[v_i0, v_i1 // T.int64(32)]) T.writes(dequantize[v_i0, v_i1]) dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * gpt_neox_layers_0_attention_query_key_value_q_scale1[v_i0, v_i1 // T.int64(32)] @T.prim_func(private=True) def dequantize2(gpt_neox_layers_0_attention_dense_q_weight1: T.Buffer((T.int64(2048), T.int64(256)), "uint32"), gpt_neox_layers_0_attention_dense_q_scale1: T.Buffer((T.int64(2048), T.int64(64)), "float16"), dequantize: T.Buffer((T.int64(2048), T.int64(2048)), "float16")): T.func_attr({"target": T.target({"keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) # with T.block("root"): compute = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16") for i0, i1 in T.grid(T.int64(2048), T.int64(2048)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(gpt_neox_layers_0_attention_dense_q_weight1[v_i0, v_i1 // T.int64(8)]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(gpt_neox_layers_0_attention_dense_q_weight1[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15))) for i0, i1 in T.grid(T.int64(2048), T.int64(2048)): with T.block("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(compute[v_i0, v_i1], gpt_neox_layers_0_attention_dense_q_scale1[v_i0, v_i1 // T.int64(32)]) T.writes(dequantize[v_i0, v_i1]) dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * gpt_neox_layers_0_attention_dense_q_scale1[v_i0, v_i1 // T.int64(32)] @T.prim_func(private=True) def dequantize3(gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight1: T.Buffer((T.int64(8192), T.int64(256)), "uint32"), gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale1: T.Buffer((T.int64(8192), T.int64(64)), "float16"), dequantize: T.Buffer((T.int64(8192), T.int64(2048)), "float16")): T.func_attr({"target": T.target({"keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) # with T.block("root"): compute = T.alloc_buffer((T.int64(8192), T.int64(2048)), "float16") for i0, i1 in T.grid(T.int64(8192), T.int64(2048)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight1[v_i0, v_i1 // T.int64(8)]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight1[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15))) for i0, i1 in T.grid(T.int64(8192), T.int64(2048)): with T.block("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(compute[v_i0, v_i1], gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale1[v_i0, v_i1 // T.int64(32)]) T.writes(dequantize[v_i0, v_i1]) dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale1[v_i0, v_i1 // T.int64(32)] @T.prim_func(private=True) def dequantize4(gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight1: T.Buffer((T.int64(2048), T.int64(1024)), "uint32"), gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale1: 
T.Buffer((T.int64(2048), T.int64(256)), "float16"), dequantize: T.Buffer((T.int64(2048), T.int64(8192)), "float16")): T.func_attr({"target": T.target({"keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) # with T.block("root"): compute = T.alloc_buffer((T.int64(2048), T.int64(8192)), "float16") for i0, i1 in T.grid(T.int64(2048), T.int64(8192)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight1[v_i0, v_i1 // T.int64(8)]) T.writes(compute[v_i0, v_i1]) compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight1[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15))) for i0, i1 in T.grid(T.int64(2048), T.int64(8192)): with T.block("dequantize"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) T.reads(compute[v_i0, v_i1], gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale1[v_i0, v_i1 // T.int64(32)]) T.writes(dequantize[v_i0, v_i1]) dequantize[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale1[v_i0, v_i1 // T.int64(32)] @T.prim_func def fused_rope(var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, apply_rope: T.int32): T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) seq_len = T.int32() qkv = T.match_buffer(var_qkv, (seq_len, 24, 256), "float16") position_map = T.match_buffer(var_position_map, (seq_len,), "int32", offset_factor=1) q = T.match_buffer(var_q, (seq_len, 8, 256), "float16") k = T.match_buffer(var_k, (seq_len, 8, 256), "float16") v = T.match_buffer(var_v, (seq_len, 8, 256), "float16") # with T.block("root"): for iters_0, iters_1, iters_2 in T.grid(seq_len, 24, 256): with T.block("llama_fused_rope"): s, h, d = T.axis.remap("SSS", [iters_0, iters_1, iters_2]) T.reads(position_map[s], qkv[s, h, d - 32:d - 32 + 65]) T.writes(q[s, h, d], k[s, h - 8, d], v[s, h - 16, d]) if h < 8: freq = T.float32() q[s, h, d] = T.if_then_else(apply_rope > 0 and d < 64, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[s, h, d]) + T.sin(freq) * T.Cast("float32", T.if_then_else(d < 32, qkv[s, h, d + 32] * T.float16(-1.0), qkv[s, h, d - 32]))), where={freq: T.Cast("float32", position_map[s]) / T.pow(T.float32(10000.0), T.Cast("float32", d * 2 % 64) / T.float32(64.0))}), qkv[s, h, d]) else: if h < 16: freq = T.float32() k[s, h - 8, d] = T.if_then_else(apply_rope > 0 and d < 64, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[s, h, d]) + T.sin(freq) * T.Cast("float32", T.if_then_else(d < 32, qkv[s, h, d + 32] * T.float16(-1.0), qkv[s, h, d - 32]))), where={freq: T.Cast("float32", position_map[s]) / T.pow(T.float32(10000.0), T.Cast("float32", d * 2 % 64) / T.float32(64.0))}), qkv[s, h, d]) else: v[s, h - 16, d] = qkv[s, h, d] @T.prim_func def gather_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], 
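# Note: fused_rope splits the fused qkv tensor (24 heads = 8 q + 8 k +
# 8 v) and applies rotary embedding only to the first 64 of 256 dims
# (halves of 32), i.e. GPT-NeoX-style partial rotary.  Rough NumPy
# reference for one head vector:
import numpy as np

def partial_rope(x, pos, rot_dim=64, theta=10000.0):
    out = x.astype(np.float32).copy()
    j = np.arange(rot_dim)
    freq = pos / theta ** ((j * 2 % rot_dim) / rot_dim)
    half = rot_dim // 2
    rotated = np.concatenate([-out[half:rot_dim], out[:half]])
    out[:rot_dim] = np.cos(freq) * out[:rot_dim] + np.sin(freq) * rotated
    return out.astype(np.float16)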
"kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) m, n = T.int32(is_size_var=True), T.int32(is_size_var=True) src = T.match_buffer(var_src, (m, n)) batch_size = T.int32(is_size_var=True) indices = T.match_buffer(var_indices, (batch_size,), "int32") dst = T.match_buffer(var_dst, (batch_size, n)) # with T.block("root"): for b, j in T.grid(batch_size, n): with T.block("gather_2d"): vb, vj = T.axis.remap("SS", [b, j]) T.reads(src[indices[vb], vj], indices[vb]) T.writes(dst[vb, vj]) dst[vb, vj] = src[indices[vb], vj] @T.prim_func(private=True) def index(var_layer_norm32: T.handle, index: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")): T.func_attr({"target": T.target({"keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) seq_len = T.int64() layer_norm32 = T.match_buffer(var_layer_norm32, (T.int64(1), seq_len, T.int64(2048)), "float16") # with T.block("root"): for i, _, k in T.grid(T.int64(1), T.int64(1), T.int64(2048)): with T.block("index"): v_i, v__, v_k = T.axis.remap("SSS", [i, _, k]) T.reads(layer_norm32[v_i, seq_len - T.int64(1), v_k]) T.writes(index[v_i, v__, v_k]) index[v_i, v__, v_k] = layer_norm32[v_i, seq_len - T.int64(1), v_k] @T.prim_func def merge_state_inplace(v: T.handle, s: T.handle, v_other: T.handle, s_other: T.handle): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) N, H, D = T.int32(is_size_var=True), T.int32(is_size_var=True), T.int32(is_size_var=True) V = T.match_buffer(v, (N, H, D), "float16") S = T.match_buffer(s, (N, H)) V_other = T.match_buffer(v_other, (N, H, D), "float16") S_other = T.match_buffer(s_other, (N, H)) # with T.block("root"): for bx in T.thread_binding(N, thread="blockIdx.x"): for by in T.thread_binding(2, thread="blockIdx.y"): for ty in T.thread_binding(4, thread="threadIdx.y"): for tx in T.thread_binding(64, thread="threadIdx.x"): with T.block("merge"): T.reads(S[bx, ty + by * 4], S_other[bx, ty + by * 4], V[bx, ty + by * 4, tx * 4:tx * 4 + 4], V_other[bx, ty + by * 4, tx * 4:tx * 4 + 4]) T.writes(V[bx, ty + by * 4, tx * 4:tx * 4 + 4], S[bx, ty + by * 4]) s_val = T.alloc_buffer((1,), scope="local") s_other_val = T.alloc_buffer((1,), scope="local") s_max = T.alloc_buffer((1,), scope="local") scale = T.alloc_buffer((1,), scope="local") other_scale = T.alloc_buffer((1,), scope="local") v_vec = T.alloc_buffer((4,), "float16", scope="local") v_other_vec = T.alloc_buffer((4,), "float16", scope="local") s_val[0] = S[bx, ty + by * 4] s_other_val[0] = S_other[bx, ty + by * 4] s_max[0] = T.max(s_val[0], s_other_val[0]) s_val[0] = T.exp2(s_val[0] - s_max[0]) s_other_val[0] = T.exp2(s_other_val[0] - s_max[0]) scale[0] = s_val[0] / (s_val[0] + s_other_val[0]) other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0]) for vec in 
T.vectorized(4): v_vec[vec] = V[bx, ty + by * 4, tx * 4 + vec] for vec in T.vectorized(4): v_other_vec[vec] = V_other[bx, ty + by * 4, tx * 4 + vec] for vec in range(4): v_vec[vec] = T.Cast("float16", T.Cast("float32", v_vec[vec]) * scale[0] + T.Cast("float32", v_other_vec[vec]) * other_scale[0]) for vec in T.vectorized(4): V[bx, ty + by * 4, tx * 4 + vec] = v_vec[vec] S[bx, ty + by * 4] = T.log2(s_val[0] + s_other_val[0]) + s_max[0] @T.prim_func def scatter_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) batch_size, n = T.int32(is_size_var=True), T.int32(is_size_var=True) src = T.match_buffer(var_src, (batch_size, n)) indices = T.match_buffer(var_indices, (batch_size,), "int32") m = T.int32(is_size_var=True) dst = T.match_buffer(var_dst, (m, n)) # with T.block("root"): for b, j in T.grid(batch_size, n): with T.block("scatter_2d"): vb, vj = T.axis.remap("SS", [b, j]) T.reads(src[vb, vj], indices[vb]) T.writes(dst[indices[vb], vj]) dst[indices[vb], vj] = src[vb, vj] @T.prim_func def softmax_with_chunked_sum(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle, var_softmax: T.handle): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True) A = T.match_buffer(var_A, (batch_size, vocab_size)) temperature = T.match_buffer(var_temperature, (batch_size,)) num_chunks = T.int64(is_size_var=True) chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks)) chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks)) softmax = T.match_buffer(var_softmax, (batch_size, vocab_size)) # with T.block("root"): temp_max_shared = T.alloc_buffer((batch_size,), scope="shared") temp_sum_shared = T.alloc_buffer((batch_size,), scope="shared") for l0_l1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"): for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): with T.block("max"): v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks) v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1) T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks) T.reads(chunked_max[v0, v1]) T.writes(temp_max_shared[v0]) with T.init(): temp_max_shared[v0] = T.float32(-340282346638528859811704183484516925440.0) temp_max_shared[v0] = T.max(temp_max_shared[v0], chunked_max[v0, v1]) for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): with T.block("sum_exp"): v0 = 
T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks) v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1) T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks) T.reads(temperature[v0], chunked_sum[v0, v1], chunked_max[v0, v1], temp_max_shared[v0]) T.writes(temp_sum_shared[v0]) with T.init(): temp_sum_shared[v0] = T.float32(0.0) temp_sum_shared[v0] = temp_sum_shared[v0] + T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(chunked_sum[v0, v1] + chunked_max[v0, v1] - temp_max_shared[v0]), T.Cast("float32", chunked_max[v0, v1] == temp_max_shared[v0]) * chunked_sum[v0, v1]) for l2_0 in T.serial(T.int64(16), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}): for l2_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"): for l2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"): with T.block("log_pad"): v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks) v1 = T.axis.spatial(num_chunks, l0_l1_fused % num_chunks) v2 = T.axis.spatial(T.int64(4096), l2_0 * T.int64(256) + l2_1 * T.int64(32) + l2_2) T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2], temp_sum_shared[v0], temp_max_shared[v0]) T.writes(softmax[v0, v1 * T.int64(4096) + v2]) if v1 * T.int64(4096) + v2 < vocab_size: softmax[v0, v1 * T.int64(4096) + v2] = T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A[v0, v1 * T.int64(4096) + v2] / temperature[v0] - (T.log(temp_sum_shared[v0]) + temp_max_shared[v0])), T.Cast("float32", A[v0, v1 * T.int64(4096) + v2] == temp_max_shared[v0]) / temp_sum_shared[v0]) @T.prim_func def tir_kv_cache_debug_get_kv(var_pages: T.handle, var_position_map: T.handle, var_k_data: T.handle, var_v_data: T.handle, layer_id: T.int64): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) num_pages, page_size = T.int64(), T.int64(is_size_var=True) pages = T.match_buffer(var_pages, (num_pages, 2, 8, page_size, 256), "float16", offset_factor=1) seqlen = T.int64(is_size_var=True) position_map = T.match_buffer(var_position_map, (seqlen,), "int32", offset_factor=1) k_data = T.match_buffer(var_k_data, (16, seqlen, 8, 256), "float16") v_data = T.match_buffer(var_v_data, (16, seqlen, 8, 256), "float16") # with T.block("root"): for p, h, d in T.grid(seqlen, 8, 256): with T.block("copy0"): vp, vh, vd = T.axis.remap("SSS", [p, h, d]) T.reads(position_map[vp], pages[T.Cast("int64", position_map[vp]) // page_size, 0:2, vh, T.Cast("int64", position_map[vp]) % page_size, vd]) T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) position: T.int32 = position_map[vp] k_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 0, vh, T.Cast("int64", position) % page_size, vd] v_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 1, vh, T.Cast("int64", position) % page_size, vd] @T.prim_func def tir_kv_cache_transpose_append(var_pages: T.handle, var_k_data: T.handle, var_v_data: T.handle, var_position_map: T.handle): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", 
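# Note: softmax_with_chunked_sum finishes what chunk_lse started: the
# per-chunk (log-sum, max) pairs are reduced to global statistics and the
# probabilities are exp(x/T - (log(sum) + max)).  NumPy model of one row
# in the temperature > eps case:
import numpy as np

def softmax_from_chunks(x, temperature, chunked_sum, chunked_max):
    g_max = chunked_max.max()
    g_sum = np.exp(chunked_sum + chunked_max - g_max).sum()
    return np.exp(x / temperature - (np.log(g_sum) + g_max))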
"max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.noalias": T.bool(True)}) num_pages = T.int64() pages = T.match_buffer(var_pages, (num_pages, 2, 8, 16, 256), "float16", offset_factor=1) ntoken = T.int64(is_size_var=True) k_data = T.match_buffer(var_k_data, (ntoken, 8, 256), "float16") v_data = T.match_buffer(var_v_data, (ntoken, 8, 256), "float16") position_map = T.match_buffer(var_position_map, (ntoken,), "int32", offset_factor=1) # with T.block("root"): for global_pos, h, f in T.grid(ntoken, 8, 256): if position_map[global_pos] != -1: with T.block("k_transpose_append"): vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf]) position: T.int32 = position_map[vgpos] pages[position // 16, 0, vh, position % 16, vf] = k_data[vgpos, vh, vf] with T.block("v_transpose_append"): vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) T.reads(position_map[vgpos], v_data[vgpos, vh, vf]) T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) position: T.int32 = position_map[vgpos] pages[position // 16, 1, vh, position % 16, vf] = v_data[vgpos, vh, vf] @T.prim_func def tree_attn_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, tree_order_indptr_handle: T.handle, tree_order_handle: T.handle): T.func_attr({"target": T.target({"host": {"keys": ["arm_cpu", "cpu"], "kind": "llvm", "mtriple": "aarch64-linux-android", "tag": ""}, "keys": ["opencl", "gpu"], "kind": "opencl", "max_function_args": 128, "max_num_threads": 256, "max_shared_memory_per_block": 16384, "max_threads_per_block": 256, "tag": "", "texture_spatial_limit": 16384, "thread_warp_size": 1}), "tir.is_scheduled": 1}) total_len = T.int32(is_size_var=True) q = T.match_buffer(var_q, (total_len, 8, 256), "float16") batch_size = T.int32(is_size_var=True) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1) max_num_pages = T.int32(is_size_var=True) pages = T.match_buffer(var_pages, (max_num_pages, 2, 8, 32, 256), "float16") page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1) nnz_pages = T.int32(is_size_var=True) page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1) length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1) k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1) q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1) output = T.match_buffer(var_output, (total_len, 8, 256), "float16") lse = T.match_buffer(var_lse, (total_len, 8)) tree_order_indptr = T.match_buffer(tree_order_indptr_handle, (batch_size + 1,), "int32", offset_factor=1) total_tree_order_len = T.int32(is_size_var=True) tree_order = T.match_buffer(tree_order_handle, (total_tree_order_len, 2), "int32", offset_factor=1) # with T.block("root"): assert rotary_mode == 0, "Inline rotary mode is not supported in tree attention." 
for lbx in T.thread_binding(16, thread="blockIdx.x"): for lby in T.thread_binding(8, thread="blockIdx.y"): for lty in T.thread_binding(4, thread="threadIdx.y"): for ltx in T.thread_binding(32, thread="threadIdx.x"): with T.block("attn"): bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx]) T.reads() T.writes() tile_id = T.alloc_buffer((1,), "int32", scope="local") batch_idx = T.alloc_buffer((1,), "int32", scope="local") batch_tiles = T.alloc_buffer((1,), "int32", scope="local") batch_rows = T.alloc_buffer((1,), "int32", scope="local") iterator = T.alloc_buffer((1,), "int32", scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") Q_smem = T.alloc_buffer((16, 256), "float16", scope="shared") K_smem = T.alloc_buffer((32, 256), "float16", scope="shared") V_smem = T.alloc_buffer((32, 256), "float16", scope="shared") S_smem = T.alloc_buffer((16, 32), scope="shared") S_local = T.alloc_buffer((16, 32), scope="local") O_local = T.alloc_buffer((16, 256), scope="local") m_smem = T.alloc_buffer((16,), scope="shared") m_prev_smem = T.alloc_buffer((16,), scope="shared") d_smem = T.alloc_buffer((16,), scope="shared") m_new = T.alloc_buffer((1,), scope="local") m_prev = T.alloc_buffer((1,), scope="local") d_new = T.alloc_buffer((1,), scope="local") tile_id[0] = bx batch_idx[0] = 0 batch_rows[0] = q_indptr[1] - q_indptr[0] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 while T.tvm_thread_invariant(batch_idx[0] < batch_size): while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size: tile_id[0] = tile_id[0] - batch_tiles[0] batch_idx[0] = batch_idx[0] + 1 if batch_idx[0] < batch_size: b_idx: T.int32 = batch_idx[0] batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx] batch_tiles[0] = (batch_rows[0] + 16 - 1) // 16 if T.tvm_thread_invariant(batch_idx[0] < batch_size): b_idx: T.int32 = batch_idx[0] LH_start: T.int32 = tile_id[0] * 16 q_indptr_val: T.int32 = q_indptr[b_idx] cur_page_indptr_begin: T.int32 = page_indptr[b_idx] cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1] kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 32 + length_info[b_idx], 0) T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: m_smem[row] = T.float32(-50000.0) d_smem[row] = T.float32(1.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads() T.writes(O_local[i, j]) O_local[i, j] = T.float32(0.0) T.tvm_storage_sync("shared") for li_lj_fused_0 in range(8): for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for li_lj_fused_3 in T.vectorized(4): with T.block("Q_load"): i = T.axis.spatial(16, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 256) j = T.axis.spatial(256, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = q_indptr_val + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: freq = T.float32() Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * 
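# Note: tree_attn_paged_kv computes a sequence's KV length from its page
# list: all pages but the last are full (32 slots per page in this
# kernel) and length_info holds the last page's occupancy; the K/V loads
# then translate flat offsets through page_values.  Sketch:
def kv_length(num_pages_for_seq, last_page_len, page_size=32):
    if num_pages_for_seq == 0:
        return 0
    return (num_pages_for_seq - 1) * page_size + last_page_len

def kv_slot(page_values, page_indptr_begin, seq_offset, page_size=32):
    page_no = page_values[page_indptr_begin + seq_offset // page_size]
    return page_no, seq_offset % page_size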
T.Cast("float32", T.if_then_else(j < 128, q[cur_L, cur_H_qo, j + 128] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 128]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 256) / T.float32(256.0))}), q[cur_L, cur_H_qo, j]) else: Q_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") for iterator_1 in range((kv_chunk_len[0] + 31) // 32): L_kv_start: T.int32 = iterator_1 * 32 for lz_ly_fused_0 in range(16): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("K_load"): i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 256) j = T.axis.spatial(256, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32 = cur_L page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 32] page_offset: T.int32 = seq_offset % 32 K_smem[i, j] = pages[page_no, 0, by, page_offset, j] else: K_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") for lz_ly_fused_0 in range(16): for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"): for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"): for lz_ly_fused_3 in T.vectorized(4): with T.block("V_load"): i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 256) j = T.axis.spatial(256, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 256) T.reads() T.writes() cur_L: T.int32 = L_kv_start + i if cur_L < kv_chunk_len[0]: seq_offset: T.int32 = cur_L page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 32] page_offset: T.int32 = seq_offset % 32 V_smem[i, j] = pages[page_no, 1, by, page_offset, j] else: V_smem[i, j] = T.float16(0.0) T.tvm_storage_sync("shared") with T.block(""): T.reads(Q_smem[0:16, 0:256], K_smem[0:32, 0:256]) T.writes(S_local[0:16, 0:32]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(2, 2): with T.block("S_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 2 + li_1_init) j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 2 + lj_1_init) T.reads() T.writes(S_local[i, j]) S_local[i, j] = T.float32(0.0) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, li_1, lj_1, lk_1 in T.grid(32, 2, 2, 8): with T.block("S_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 2 + li_1) j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 2 + lj_1) k = T.axis.reduce(256, lk_0 * 8 + lk_1) T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k]) T.writes(S_local[i, j]) S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.090168440055560212) T.tvm_storage_sync("shared") for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(2, 2): with T.block("S_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 
16 * 2 + li_1) j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 2 + lj_1) T.reads(S_local[i, j]) T.writes(S_smem[i, j]) S_smem[i, j] = S_local[i, j] T.tvm_storage_sync("shared") for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update1"): T.reads(m_smem[row], kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i]) T.writes(m_prev[i], m_new[i], d_new[i]) m_prev[i] = m_smem[row] m_new[i] = m_smem[row] row_: T.int32 = LH_start + row for j in range(32): if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]): m_new[i] = T.max(m_new[i], S_smem[row, j]) d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx with T.block("update"): T.reads(kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i]) T.writes(S_smem[row, 0:32]) for j in range(32): if row < 16: row_: T.int32 = LH_start + row if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - 
(q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]): S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i]) else: S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i]) for i in range(1): row: T.int32 = i * 32 * 4 + ty * 32 + tx if row < 16: with T.block("update"): T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i]) T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row]) for j in range(32): d_new[i] = d_new[i] + S_smem[row, j] m_smem[row] = m_new[i] d_smem[row] = d_new[i] m_prev_smem[row] = m_prev[i] T.tvm_storage_sync("shared") with T.block(""): T.reads(m_prev_smem[0:16], m_smem[0:16], S_smem[0:16, 0:32], V_smem[0:32, 0:256]) T.writes(O_local[0:16, 0:256]) for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"): for li_1_init, lj_1_init in T.grid(4, 8): with T.block("O_gemm_init"): i = T.axis.spatial(16, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 32 * 4 + li_1_init) j = T.axis.spatial(256, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 32 * 8 + lj_1_init) T.reads() T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8): with T.block("O_gemm_update"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) k = T.axis.reduce(32, lk_0 * 8 + lk_1) T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j]) T.writes(O_local[i, j]) O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j]) for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): for li_1, lj_1 in T.grid(4, 8): with T.block("O_store"): i = T.axis.spatial(16, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 32 * 4 + li_1) j = T.axis.spatial(256, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 32 * 8 + lj_1) T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i]) T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i]) for li_0 in range(1): for li_1 in T.thread_binding(4, thread="threadIdx.y"): for li_2 in T.thread_binding(32, thread="threadIdx.x"): with T.block("lse_store"): i = T.axis.spatial(16, li_0 * 128 + li_1 * 32 + li_2) T.where((li_0 * 4 + li_1) * 32 + li_2 < 16) T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i]) T.writes(lse[q_indptr[b_idx] + (LH_start + i), by]) cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) cur_H_qo: T.int32 = by if cur_L < q_indptr[b_idx + 1]: lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i]) tile_id[0] = tile_id[0] + 16 @R.function def alloc_embedding_tensor() -> R.Tensor((2048, 2048), dtype="float16"): R.func_attr({"relax.memory_plan_dynamic_func_output": True}) gv: R.Tensor((2048, 2048), dtype="float16") = R.builtin.alloc_tensor(R.shape([2048, 2048]), R.dtype("float16"), R.prim_value(0), R.str("global")) return gv @R.function def batch_decode(input_embeds: R.Tensor(("batch_size", 1, 2048), 
dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), 
R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 
64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"))) -> R.Tuple(R.Tensor(("batch_size", 1, "vocab_size"), dtype="float32"), R.Object): batch_size = T.int64() vocab_size = T.int64() R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module with R.dataflow(): gpt_neox_embed_in_q_weight4: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[0] gpt_neox_embed_in_q_scale4: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[1] gpt_neox_layers_0_input_layernorm_weight4: 
R.Tensor((2048,), dtype="float16") = packed_params[2] gpt_neox_layers_0_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[3] gpt_neox_layers_0_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[4] gpt_neox_layers_0_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[5] gpt_neox_layers_0_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[6] gpt_neox_layers_0_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[7] gpt_neox_layers_0_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[8] gpt_neox_layers_0_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[9] gpt_neox_layers_0_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[10] gpt_neox_layers_0_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[11] gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[12] gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[13] gpt_neox_layers_0_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[14] gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[15] gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[16] gpt_neox_layers_0_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[17] gpt_neox_layers_1_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[18] gpt_neox_layers_1_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[19] gpt_neox_layers_1_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[20] gpt_neox_layers_1_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[21] gpt_neox_layers_1_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[22] gpt_neox_layers_1_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[23] gpt_neox_layers_1_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[24] gpt_neox_layers_1_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[25] gpt_neox_layers_1_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[26] gpt_neox_layers_1_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[27] gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[28] gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[29] gpt_neox_layers_1_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[30] gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[31] gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[32] gpt_neox_layers_1_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[33] gpt_neox_layers_2_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[34] gpt_neox_layers_2_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[35] gpt_neox_layers_2_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = 
packed_params[36] gpt_neox_layers_2_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[37] gpt_neox_layers_2_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[38] gpt_neox_layers_2_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[39] gpt_neox_layers_2_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[40] gpt_neox_layers_2_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[41] gpt_neox_layers_2_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[42] gpt_neox_layers_2_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[43] gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[44] gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[45] gpt_neox_layers_2_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[46] gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[47] gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[48] gpt_neox_layers_2_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[49] gpt_neox_layers_3_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[50] gpt_neox_layers_3_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[51] gpt_neox_layers_3_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[52] gpt_neox_layers_3_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[53] gpt_neox_layers_3_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[54] gpt_neox_layers_3_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[55] gpt_neox_layers_3_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[56] gpt_neox_layers_3_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[57] gpt_neox_layers_3_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[58] gpt_neox_layers_3_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[59] gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[60] gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[61] gpt_neox_layers_3_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[62] gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[63] gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[64] gpt_neox_layers_3_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[65] gpt_neox_layers_4_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[66] gpt_neox_layers_4_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[67] gpt_neox_layers_4_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[68] gpt_neox_layers_4_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[69] gpt_neox_layers_4_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[70] 
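# The unpacking above follows a fixed stride: packed_params[0:2] hold the embed_in quantized
# weight and scale, and decoder layer k occupies packed_params[2 + 16*k : 18 + 16*k]. The tail
# of the tuple (unpacked below) holds the final layer norm at indices 258/259 and the embed_out
# quantized weight and scale at 260/261.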
gpt_neox_layers_4_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[71] gpt_neox_layers_4_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[72] gpt_neox_layers_4_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[73] gpt_neox_layers_4_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[74] gpt_neox_layers_4_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[75] gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[76] gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[77] gpt_neox_layers_4_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[78] gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[79] gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[80] gpt_neox_layers_4_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[81] gpt_neox_layers_5_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[82] gpt_neox_layers_5_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[83] gpt_neox_layers_5_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[84] gpt_neox_layers_5_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[85] gpt_neox_layers_5_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[86] gpt_neox_layers_5_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[87] gpt_neox_layers_5_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[88] gpt_neox_layers_5_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[89] gpt_neox_layers_5_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[90] gpt_neox_layers_5_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[91] gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[92] gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[93] gpt_neox_layers_5_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[94] gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[95] gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[96] gpt_neox_layers_5_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[97] gpt_neox_layers_6_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[98] gpt_neox_layers_6_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[99] gpt_neox_layers_6_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[100] gpt_neox_layers_6_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[101] gpt_neox_layers_6_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[102] gpt_neox_layers_6_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[103] gpt_neox_layers_6_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[104] 
gpt_neox_layers_6_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[105] gpt_neox_layers_6_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[106] gpt_neox_layers_6_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[107] gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[108] gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[109] gpt_neox_layers_6_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[110] gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[111] gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[112] gpt_neox_layers_6_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[113] gpt_neox_layers_7_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[114] gpt_neox_layers_7_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[115] gpt_neox_layers_7_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[116] gpt_neox_layers_7_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[117] gpt_neox_layers_7_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[118] gpt_neox_layers_7_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[119] gpt_neox_layers_7_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[120] gpt_neox_layers_7_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[121] gpt_neox_layers_7_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[122] gpt_neox_layers_7_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[123] gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[124] gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[125] gpt_neox_layers_7_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[126] gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[127] gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[128] gpt_neox_layers_7_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[129] gpt_neox_layers_8_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[130] gpt_neox_layers_8_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[131] gpt_neox_layers_8_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[132] gpt_neox_layers_8_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[133] gpt_neox_layers_8_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[134] gpt_neox_layers_8_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[135] gpt_neox_layers_8_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[136] gpt_neox_layers_8_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[137] gpt_neox_layers_8_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[138] 
gpt_neox_layers_8_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[139] gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[140] gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[141] gpt_neox_layers_8_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[142] gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[143] gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[144] gpt_neox_layers_8_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[145] gpt_neox_layers_9_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[146] gpt_neox_layers_9_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[147] gpt_neox_layers_9_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[148] gpt_neox_layers_9_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[149] gpt_neox_layers_9_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[150] gpt_neox_layers_9_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[151] gpt_neox_layers_9_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[152] gpt_neox_layers_9_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[153] gpt_neox_layers_9_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[154] gpt_neox_layers_9_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[155] gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[156] gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[157] gpt_neox_layers_9_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[158] gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[159] gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[160] gpt_neox_layers_9_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[161] gpt_neox_layers_10_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[162] gpt_neox_layers_10_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[163] gpt_neox_layers_10_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[164] gpt_neox_layers_10_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[165] gpt_neox_layers_10_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[166] gpt_neox_layers_10_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[167] gpt_neox_layers_10_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[168] gpt_neox_layers_10_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[169] gpt_neox_layers_10_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[170] gpt_neox_layers_10_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[171] gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[172] 
gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[173] gpt_neox_layers_10_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[174] gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[175] gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[176] gpt_neox_layers_10_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[177] gpt_neox_layers_11_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[178] gpt_neox_layers_11_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[179] gpt_neox_layers_11_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[180] gpt_neox_layers_11_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[181] gpt_neox_layers_11_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[182] gpt_neox_layers_11_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[183] gpt_neox_layers_11_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[184] gpt_neox_layers_11_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[185] gpt_neox_layers_11_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[186] gpt_neox_layers_11_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[187] gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[188] gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[189] gpt_neox_layers_11_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[190] gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[191] gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[192] gpt_neox_layers_11_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[193] gpt_neox_layers_12_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[194] gpt_neox_layers_12_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[195] gpt_neox_layers_12_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[196] gpt_neox_layers_12_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[197] gpt_neox_layers_12_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[198] gpt_neox_layers_12_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[199] gpt_neox_layers_12_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[200] gpt_neox_layers_12_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[201] gpt_neox_layers_12_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[202] gpt_neox_layers_12_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[203] gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[204] gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[205] gpt_neox_layers_12_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), 
dtype="float32") = packed_params[206] gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[207] gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[208] gpt_neox_layers_12_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[209] gpt_neox_layers_13_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[210] gpt_neox_layers_13_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[211] gpt_neox_layers_13_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[212] gpt_neox_layers_13_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[213] gpt_neox_layers_13_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[214] gpt_neox_layers_13_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[215] gpt_neox_layers_13_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[216] gpt_neox_layers_13_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[217] gpt_neox_layers_13_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[218] gpt_neox_layers_13_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[219] gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[220] gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[221] gpt_neox_layers_13_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[222] gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[223] gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[224] gpt_neox_layers_13_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[225] gpt_neox_layers_14_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[226] gpt_neox_layers_14_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[227] gpt_neox_layers_14_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[228] gpt_neox_layers_14_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[229] gpt_neox_layers_14_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[230] gpt_neox_layers_14_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[231] gpt_neox_layers_14_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[232] gpt_neox_layers_14_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[233] gpt_neox_layers_14_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[234] gpt_neox_layers_14_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[235] gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[236] gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[237] gpt_neox_layers_14_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[238] gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[239] 
gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[240] gpt_neox_layers_14_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[241] gpt_neox_layers_15_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[242] gpt_neox_layers_15_input_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[243] gpt_neox_layers_15_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[244] gpt_neox_layers_15_post_attention_layernorm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[245] gpt_neox_layers_15_attention_query_key_value_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[246] gpt_neox_layers_15_attention_query_key_value_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[247] gpt_neox_layers_15_attention_query_key_value_bias4: R.Tensor((6144,), dtype="float16") = packed_params[248] gpt_neox_layers_15_attention_dense_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[249] gpt_neox_layers_15_attention_dense_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[250] gpt_neox_layers_15_attention_dense_bias4: R.Tensor((2048,), dtype="float16") = packed_params[251] gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight4: R.Tensor((8192, 256), dtype="uint32") = packed_params[252] gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale4: R.Tensor((8192, 64), dtype="float16") = packed_params[253] gpt_neox_layers_15_mlp_dense_h_to_4h_bias4: R.Tensor((8192,), dtype="float32") = packed_params[254] gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight4: R.Tensor((2048, 1024), dtype="uint32") = packed_params[255] gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale4: R.Tensor((2048, 256), dtype="float16") = packed_params[256] gpt_neox_layers_15_mlp_dense_4h_to_h_bias4: R.Tensor((2048,), dtype="float32") = packed_params[257] gpt_neox_final_layer_norm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[258] gpt_neox_final_layer_norm_bias4: R.Tensor((2048,), dtype="float16") = packed_params[259] embed_out_q_weight4: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[260] embed_out_q_scale4: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[261] layer_norm99: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(input_embeds, gpt_neox_layers_0_input_layernorm_weight4, gpt_neox_layers_0_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv245 = R.call_tir(cls.dequantize1, (gpt_neox_layers_0_attention_query_key_value_q_weight4, gpt_neox_layers_0_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims195: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv245, axes=None) matmul195: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm99, permute_dims195, out_dtype="void") add288: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul195, gpt_neox_layers_0_attention_query_key_value_bias4) reshape192: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add288, R.shape([batch_size, 1, 24, 256])) reshape193: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape192, R.shape([batch_size, 24, 256])) lv246 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape193), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape194: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = 
R.reshape(lv246, R.shape([batch_size, 1, 8, 256])) reshape195: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape194, R.shape([batch_size, 1, 2048])) lv247 = R.call_tir(cls.dequantize2, (gpt_neox_layers_0_attention_dense_q_weight4, gpt_neox_layers_0_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims196: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv247, axes=None) matmul196: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape195, permute_dims196, out_dtype="void") add289: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul196, gpt_neox_layers_0_attention_dense_bias4) layer_norm100: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(input_embeds, gpt_neox_layers_0_post_attention_layernorm_weight4, gpt_neox_layers_0_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv248 = R.call_tir(cls.dequantize3, (gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims197: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv248, axes=None) matmul197: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm100, permute_dims197, out_dtype="float32") add290: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul197, gpt_neox_layers_0_mlp_dense_h_to_4h_bias4) gelu48: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add290) astype99: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu48, dtype="float16") lv249 = R.call_tir(cls.dequantize4, (gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims198: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv249, axes=None) matmul198: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype99, permute_dims198, out_dtype="float32") add291: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul198, gpt_neox_layers_0_mlp_dense_4h_to_h_bias4) astype100: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add291, dtype="float16") add292: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype100, add289) add293: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add292, input_embeds) layer_norm101: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add293, gpt_neox_layers_1_input_layernorm_weight4, gpt_neox_layers_1_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv250 = R.call_tir(cls.dequantize1, (gpt_neox_layers_1_attention_query_key_value_q_weight4, gpt_neox_layers_1_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims199: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv250, axes=None) matmul199: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm101, permute_dims199, out_dtype="void") add294: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul199, gpt_neox_layers_1_attention_query_key_value_bias4) reshape196: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add294, R.shape([batch_size, 1, 24, 256])) reshape197: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape196, R.shape([batch_size, 24, 256])) lv251 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), 
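# ^ R.prim_value(1) above is the layer index into the paged KV cache, and the R.prim_value(1.0)
# that follows is the attention score scaling factor. The fused (batch_size, 24, 256) QKV input
# presumably splits inside the builtin as 8 query, 8 key, and 8 value heads of head_dim 256,
# which matches the (batch_size, 8, 256) attention output it returns.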
R.prim_value(T.float32(1.0)), reshape197), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape198: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv251, R.shape([batch_size, 1, 8, 256])) reshape199: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape198, R.shape([batch_size, 1, 2048])) lv252 = R.call_tir(cls.dequantize2, (gpt_neox_layers_1_attention_dense_q_weight4, gpt_neox_layers_1_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims200: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv252, axes=None) matmul200: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape199, permute_dims200, out_dtype="void") add295: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul200, gpt_neox_layers_1_attention_dense_bias4) layer_norm102: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add293, gpt_neox_layers_1_post_attention_layernorm_weight4, gpt_neox_layers_1_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv253 = R.call_tir(cls.dequantize3, (gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims201: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv253, axes=None) matmul201: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm102, permute_dims201, out_dtype="float32") add296: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul201, gpt_neox_layers_1_mlp_dense_h_to_4h_bias4) gelu49: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add296) astype101: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu49, dtype="float16") lv254 = R.call_tir(cls.dequantize4, (gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims202: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv254, axes=None) matmul202: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype101, permute_dims202, out_dtype="float32") add297: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul202, gpt_neox_layers_1_mlp_dense_4h_to_h_bias4) astype102: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add297, dtype="float16") add298: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype102, add295) add299: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add298, add293) layer_norm103: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add299, gpt_neox_layers_2_input_layernorm_weight4, gpt_neox_layers_2_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv255 = R.call_tir(cls.dequantize1, (gpt_neox_layers_2_attention_query_key_value_q_weight4, gpt_neox_layers_2_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims203: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv255, axes=None) matmul203: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm103, permute_dims203, out_dtype="void") add300: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul203, gpt_neox_layers_2_attention_query_key_value_bias4) reshape200: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add300, R.shape([batch_size, 1, 24, 256])) reshape201: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape200, 
R.shape([batch_size, 24, 256])) lv256 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape201), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape202: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv256, R.shape([batch_size, 1, 8, 256])) reshape203: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape202, R.shape([batch_size, 1, 2048])) lv257 = R.call_tir(cls.dequantize2, (gpt_neox_layers_2_attention_dense_q_weight4, gpt_neox_layers_2_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims204: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv257, axes=None) matmul204: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape203, permute_dims204, out_dtype="void") add301: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul204, gpt_neox_layers_2_attention_dense_bias4) layer_norm104: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add299, gpt_neox_layers_2_post_attention_layernorm_weight4, gpt_neox_layers_2_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv258 = R.call_tir(cls.dequantize3, (gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims205: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv258, axes=None) matmul205: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm104, permute_dims205, out_dtype="float32") add302: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul205, gpt_neox_layers_2_mlp_dense_h_to_4h_bias4) gelu50: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add302) astype103: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu50, dtype="float16") lv259 = R.call_tir(cls.dequantize4, (gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims206: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv259, axes=None) matmul206: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype103, permute_dims206, out_dtype="float32") add303: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul206, gpt_neox_layers_2_mlp_dense_4h_to_h_bias4) astype104: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add303, dtype="float16") add304: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype104, add301) add305: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add304, add299) layer_norm105: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add305, gpt_neox_layers_3_input_layernorm_weight4, gpt_neox_layers_3_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv260 = R.call_tir(cls.dequantize1, (gpt_neox_layers_3_attention_query_key_value_q_weight4, gpt_neox_layers_3_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims207: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv260, axes=None) matmul207: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm105, permute_dims207, out_dtype="void") add306: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul207, gpt_neox_layers_3_attention_query_key_value_bias4) reshape204: R.Tensor((batch_size, 1, 24, 256), 
dtype="float16") = R.reshape(add306, R.shape([batch_size, 1, 24, 256])) reshape205: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape204, R.shape([batch_size, 24, 256])) lv261 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape205), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape206: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv261, R.shape([batch_size, 1, 8, 256])) reshape207: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape206, R.shape([batch_size, 1, 2048])) lv262 = R.call_tir(cls.dequantize2, (gpt_neox_layers_3_attention_dense_q_weight4, gpt_neox_layers_3_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims208: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv262, axes=None) matmul208: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape207, permute_dims208, out_dtype="void") add307: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul208, gpt_neox_layers_3_attention_dense_bias4) layer_norm106: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add305, gpt_neox_layers_3_post_attention_layernorm_weight4, gpt_neox_layers_3_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv263 = R.call_tir(cls.dequantize3, (gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims209: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv263, axes=None) matmul209: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm106, permute_dims209, out_dtype="float32") add308: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul209, gpt_neox_layers_3_mlp_dense_h_to_4h_bias4) gelu51: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add308) astype105: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu51, dtype="float16") lv264 = R.call_tir(cls.dequantize4, (gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims210: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv264, axes=None) matmul210: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype105, permute_dims210, out_dtype="float32") add309: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul210, gpt_neox_layers_3_mlp_dense_4h_to_h_bias4) astype106: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add309, dtype="float16") add310: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype106, add307) add311: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add310, add305) layer_norm107: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add311, gpt_neox_layers_4_input_layernorm_weight4, gpt_neox_layers_4_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv265 = R.call_tir(cls.dequantize1, (gpt_neox_layers_4_attention_query_key_value_q_weight4, gpt_neox_layers_4_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims211: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv265, axes=None) matmul211: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm107, permute_dims211, out_dtype="void") add312: R.Tensor((batch_size, 
1, 6144), dtype="float16") = R.add(matmul211, gpt_neox_layers_4_attention_query_key_value_bias4) reshape208: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add312, R.shape([batch_size, 1, 24, 256])) reshape209: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape208, R.shape([batch_size, 24, 256])) lv266 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape209), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape210: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv266, R.shape([batch_size, 1, 8, 256])) reshape211: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape210, R.shape([batch_size, 1, 2048])) lv267 = R.call_tir(cls.dequantize2, (gpt_neox_layers_4_attention_dense_q_weight4, gpt_neox_layers_4_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims212: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv267, axes=None) matmul212: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape211, permute_dims212, out_dtype="void") add313: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul212, gpt_neox_layers_4_attention_dense_bias4) layer_norm108: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add311, gpt_neox_layers_4_post_attention_layernorm_weight4, gpt_neox_layers_4_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv268 = R.call_tir(cls.dequantize3, (gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims213: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv268, axes=None) matmul213: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm108, permute_dims213, out_dtype="float32") add314: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul213, gpt_neox_layers_4_mlp_dense_h_to_4h_bias4) gelu52: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add314) astype107: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu52, dtype="float16") lv269 = R.call_tir(cls.dequantize4, (gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims214: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv269, axes=None) matmul214: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype107, permute_dims214, out_dtype="float32") add315: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul214, gpt_neox_layers_4_mlp_dense_4h_to_h_bias4) astype108: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add315, dtype="float16") add316: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype108, add313) add317: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add316, add311) layer_norm109: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add317, gpt_neox_layers_5_input_layernorm_weight4, gpt_neox_layers_5_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv270 = R.call_tir(cls.dequantize1, (gpt_neox_layers_5_attention_query_key_value_q_weight4, gpt_neox_layers_5_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims215: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv270, axes=None) 
matmul215: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm109, permute_dims215, out_dtype="void") add318: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul215, gpt_neox_layers_5_attention_query_key_value_bias4) reshape212: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add318, R.shape([batch_size, 1, 24, 256])) reshape213: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape212, R.shape([batch_size, 24, 256])) lv271 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape213), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape214: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv271, R.shape([batch_size, 1, 8, 256])) reshape215: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape214, R.shape([batch_size, 1, 2048])) lv272 = R.call_tir(cls.dequantize2, (gpt_neox_layers_5_attention_dense_q_weight4, gpt_neox_layers_5_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims216: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv272, axes=None) matmul216: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape215, permute_dims216, out_dtype="void") add319: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul216, gpt_neox_layers_5_attention_dense_bias4) layer_norm110: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add317, gpt_neox_layers_5_post_attention_layernorm_weight4, gpt_neox_layers_5_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv273 = R.call_tir(cls.dequantize3, (gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims217: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv273, axes=None) matmul217: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm110, permute_dims217, out_dtype="float32") add320: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul217, gpt_neox_layers_5_mlp_dense_h_to_4h_bias4) gelu53: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add320) astype109: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu53, dtype="float16") lv274 = R.call_tir(cls.dequantize4, (gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims218: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv274, axes=None) matmul218: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype109, permute_dims218, out_dtype="float32") add321: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul218, gpt_neox_layers_5_mlp_dense_4h_to_h_bias4) astype110: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add321, dtype="float16") add322: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype110, add319) add323: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add322, add317) layer_norm111: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add323, gpt_neox_layers_6_input_layernorm_weight4, gpt_neox_layers_6_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv275 = R.call_tir(cls.dequantize1, (gpt_neox_layers_6_attention_query_key_value_q_weight4, gpt_neox_layers_6_attention_query_key_value_q_scale4), 
out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims219: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv275, axes=None) matmul219: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm111, permute_dims219, out_dtype="void") add324: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul219, gpt_neox_layers_6_attention_query_key_value_bias4) reshape216: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add324, R.shape([batch_size, 1, 24, 256])) reshape217: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape216, R.shape([batch_size, 24, 256])) lv276 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape217), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape218: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv276, R.shape([batch_size, 1, 8, 256])) reshape219: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape218, R.shape([batch_size, 1, 2048])) lv277 = R.call_tir(cls.dequantize2, (gpt_neox_layers_6_attention_dense_q_weight4, gpt_neox_layers_6_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims220: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv277, axes=None) matmul220: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape219, permute_dims220, out_dtype="void") add325: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul220, gpt_neox_layers_6_attention_dense_bias4) layer_norm112: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add323, gpt_neox_layers_6_post_attention_layernorm_weight4, gpt_neox_layers_6_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv278 = R.call_tir(cls.dequantize3, (gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims221: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv278, axes=None) matmul221: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm112, permute_dims221, out_dtype="float32") add326: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul221, gpt_neox_layers_6_mlp_dense_h_to_4h_bias4) gelu54: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add326) astype111: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu54, dtype="float16") lv279 = R.call_tir(cls.dequantize4, (gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims222: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv279, axes=None) matmul222: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype111, permute_dims222, out_dtype="float32") add327: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul222, gpt_neox_layers_6_mlp_dense_4h_to_h_bias4) astype112: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add327, dtype="float16") add328: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype112, add325) add329: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add328, add323) layer_norm113: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add329, gpt_neox_layers_7_input_layernorm_weight4, gpt_neox_layers_7_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) 
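# NOTE (editor): every gpt_neox_layers_N block in this function repeats the
# pattern visible for layers 5 and 6 above: dequantize the packed weights via
# call_tir into cls.dequantize1/2/3/4, transpose, matmul, add bias, run the
# fused-QKV paged-attention builtin, and fold attention and MLP back into the
# block input through GPT-NeoX's parallel residual (e.g. add329 = mlp_out +
# attn_out + add323). The dequantize signatures imply the q4f16_1 packing:
# cls.dequantize2 maps a (2048, 256) uint32 q_weight plus a (2048, 64) fp16
# q_scale to a (2048, 2048) fp16 matrix, i.e. eight 4-bit values per uint32
# word and one scale per group of 32 elements. A minimal NumPy sketch of that
# layout follows; the function name is hypothetical, and the zero-point of 7
# is an assumption about the quantization scheme -- the authoritative formula
# is the dequantize prim_func itself.
import numpy as np

def dequantize_q4f16_1(q_weight: np.ndarray, q_scale: np.ndarray) -> np.ndarray:
    """(n, k/8) uint32 packed weights + (n, k/32) fp16 scales -> (n, k) fp16."""
    n, k_packed = q_weight.shape
    k = k_packed * 8                                            # 8 nibbles per uint32 word
    shifts = np.arange(8, dtype=np.uint32) * np.uint32(4)
    nibbles = (q_weight[:, :, None] >> shifts) & np.uint32(0xF) # (n, k/8, 8)
    nibbles = nibbles.reshape(n, k).astype(np.float16)
    group = k // q_scale.shape[1]                               # 32 elements per scale group
    scale = np.repeat(q_scale, group, axis=1)                   # broadcast scales to (n, k)
    return ((nibbles - np.float16(7)) * scale).astype(np.float16)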
lv280 = R.call_tir(cls.dequantize1, (gpt_neox_layers_7_attention_query_key_value_q_weight4, gpt_neox_layers_7_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims223: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv280, axes=None) matmul223: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm113, permute_dims223, out_dtype="void") add330: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul223, gpt_neox_layers_7_attention_query_key_value_bias4) reshape220: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add330, R.shape([batch_size, 1, 24, 256])) reshape221: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape220, R.shape([batch_size, 24, 256])) lv281 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape221), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape222: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv281, R.shape([batch_size, 1, 8, 256])) reshape223: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape222, R.shape([batch_size, 1, 2048])) lv282 = R.call_tir(cls.dequantize2, (gpt_neox_layers_7_attention_dense_q_weight4, gpt_neox_layers_7_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims224: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv282, axes=None) matmul224: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape223, permute_dims224, out_dtype="void") add331: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul224, gpt_neox_layers_7_attention_dense_bias4) layer_norm114: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add329, gpt_neox_layers_7_post_attention_layernorm_weight4, gpt_neox_layers_7_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv283 = R.call_tir(cls.dequantize3, (gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims225: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv283, axes=None) matmul225: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm114, permute_dims225, out_dtype="float32") add332: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul225, gpt_neox_layers_7_mlp_dense_h_to_4h_bias4) gelu55: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add332) astype113: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu55, dtype="float16") lv284 = R.call_tir(cls.dequantize4, (gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims226: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv284, axes=None) matmul226: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype113, permute_dims226, out_dtype="float32") add333: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul226, gpt_neox_layers_7_mlp_dense_4h_to_h_bias4) astype114: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add333, dtype="float16") add334: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype114, add331) add335: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add334, add329) layer_norm115: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add335, 
gpt_neox_layers_8_input_layernorm_weight4, gpt_neox_layers_8_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv285 = R.call_tir(cls.dequantize1, (gpt_neox_layers_8_attention_query_key_value_q_weight4, gpt_neox_layers_8_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims227: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv285, axes=None) matmul227: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm115, permute_dims227, out_dtype="void") add336: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul227, gpt_neox_layers_8_attention_query_key_value_bias4) reshape224: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add336, R.shape([batch_size, 1, 24, 256])) reshape225: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape224, R.shape([batch_size, 24, 256])) lv286 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape225), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape226: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv286, R.shape([batch_size, 1, 8, 256])) reshape227: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape226, R.shape([batch_size, 1, 2048])) lv287 = R.call_tir(cls.dequantize2, (gpt_neox_layers_8_attention_dense_q_weight4, gpt_neox_layers_8_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims228: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv287, axes=None) matmul228: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape227, permute_dims228, out_dtype="void") add337: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul228, gpt_neox_layers_8_attention_dense_bias4) layer_norm116: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add335, gpt_neox_layers_8_post_attention_layernorm_weight4, gpt_neox_layers_8_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv288 = R.call_tir(cls.dequantize3, (gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims229: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv288, axes=None) matmul229: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm116, permute_dims229, out_dtype="float32") add338: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul229, gpt_neox_layers_8_mlp_dense_h_to_4h_bias4) gelu56: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add338) astype115: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu56, dtype="float16") lv289 = R.call_tir(cls.dequantize4, (gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims230: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv289, axes=None) matmul230: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype115, permute_dims230, out_dtype="float32") add339: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul230, gpt_neox_layers_8_mlp_dense_4h_to_h_bias4) astype116: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add339, dtype="float16") add340: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype116, add337) add341: 
R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add340, add335) layer_norm117: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add341, gpt_neox_layers_9_input_layernorm_weight4, gpt_neox_layers_9_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv290 = R.call_tir(cls.dequantize1, (gpt_neox_layers_9_attention_query_key_value_q_weight4, gpt_neox_layers_9_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims231: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv290, axes=None) matmul231: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm117, permute_dims231, out_dtype="void") add342: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul231, gpt_neox_layers_9_attention_query_key_value_bias4) reshape228: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add342, R.shape([batch_size, 1, 24, 256])) reshape229: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape228, R.shape([batch_size, 24, 256])) lv291 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape229), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape230: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv291, R.shape([batch_size, 1, 8, 256])) reshape231: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape230, R.shape([batch_size, 1, 2048])) lv292 = R.call_tir(cls.dequantize2, (gpt_neox_layers_9_attention_dense_q_weight4, gpt_neox_layers_9_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims232: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv292, axes=None) matmul232: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape231, permute_dims232, out_dtype="void") add343: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul232, gpt_neox_layers_9_attention_dense_bias4) layer_norm118: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add341, gpt_neox_layers_9_post_attention_layernorm_weight4, gpt_neox_layers_9_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv293 = R.call_tir(cls.dequantize3, (gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims233: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv293, axes=None) matmul233: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm118, permute_dims233, out_dtype="float32") add344: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul233, gpt_neox_layers_9_mlp_dense_h_to_4h_bias4) gelu57: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add344) astype117: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu57, dtype="float16") lv294 = R.call_tir(cls.dequantize4, (gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims234: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv294, axes=None) matmul234: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype117, permute_dims234, out_dtype="float32") add345: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul234, gpt_neox_layers_9_mlp_dense_4h_to_h_bias4) astype118: R.Tensor((batch_size, 1, 2048), 
dtype="float16") = R.astype(add345, dtype="float16") add346: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype118, add343) add347: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add346, add341) layer_norm119: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add347, gpt_neox_layers_10_input_layernorm_weight4, gpt_neox_layers_10_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv295 = R.call_tir(cls.dequantize1, (gpt_neox_layers_10_attention_query_key_value_q_weight4, gpt_neox_layers_10_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims235: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv295, axes=None) matmul235: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm119, permute_dims235, out_dtype="void") add348: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul235, gpt_neox_layers_10_attention_query_key_value_bias4) reshape232: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add348, R.shape([batch_size, 1, 24, 256])) reshape233: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape232, R.shape([batch_size, 24, 256])) lv296 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape233), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape234: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv296, R.shape([batch_size, 1, 8, 256])) reshape235: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape234, R.shape([batch_size, 1, 2048])) lv297 = R.call_tir(cls.dequantize2, (gpt_neox_layers_10_attention_dense_q_weight4, gpt_neox_layers_10_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims236: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv297, axes=None) matmul236: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape235, permute_dims236, out_dtype="void") add349: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul236, gpt_neox_layers_10_attention_dense_bias4) layer_norm120: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add347, gpt_neox_layers_10_post_attention_layernorm_weight4, gpt_neox_layers_10_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv298 = R.call_tir(cls.dequantize3, (gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims237: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv298, axes=None) matmul237: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm120, permute_dims237, out_dtype="float32") add350: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul237, gpt_neox_layers_10_mlp_dense_h_to_4h_bias4) gelu58: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add350) astype119: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu58, dtype="float16") lv299 = R.call_tir(cls.dequantize4, (gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims238: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv299, axes=None) matmul238: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype119, permute_dims238, out_dtype="float32") 
add351: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul238, gpt_neox_layers_10_mlp_dense_4h_to_h_bias4) astype120: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add351, dtype="float16") add352: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype120, add349) add353: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add352, add347) layer_norm121: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add353, gpt_neox_layers_11_input_layernorm_weight4, gpt_neox_layers_11_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv300 = R.call_tir(cls.dequantize1, (gpt_neox_layers_11_attention_query_key_value_q_weight4, gpt_neox_layers_11_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims239: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv300, axes=None) matmul239: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm121, permute_dims239, out_dtype="void") add354: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul239, gpt_neox_layers_11_attention_query_key_value_bias4) reshape236: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add354, R.shape([batch_size, 1, 24, 256])) reshape237: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape236, R.shape([batch_size, 24, 256])) lv301 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape237), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape238: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv301, R.shape([batch_size, 1, 8, 256])) reshape239: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape238, R.shape([batch_size, 1, 2048])) lv302 = R.call_tir(cls.dequantize2, (gpt_neox_layers_11_attention_dense_q_weight4, gpt_neox_layers_11_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims240: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv302, axes=None) matmul240: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape239, permute_dims240, out_dtype="void") add355: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul240, gpt_neox_layers_11_attention_dense_bias4) layer_norm122: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add353, gpt_neox_layers_11_post_attention_layernorm_weight4, gpt_neox_layers_11_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv303 = R.call_tir(cls.dequantize3, (gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims241: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv303, axes=None) matmul241: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm122, permute_dims241, out_dtype="float32") add356: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul241, gpt_neox_layers_11_mlp_dense_h_to_4h_bias4) gelu59: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add356) astype121: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu59, dtype="float16") lv304 = R.call_tir(cls.dequantize4, (gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims242: R.Tensor((8192, 2048), 
dtype="float16") = R.permute_dims(lv304, axes=None) matmul242: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype121, permute_dims242, out_dtype="float32") add357: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul242, gpt_neox_layers_11_mlp_dense_4h_to_h_bias4) astype122: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add357, dtype="float16") add358: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype122, add355) add359: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add358, add353) layer_norm123: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add359, gpt_neox_layers_12_input_layernorm_weight4, gpt_neox_layers_12_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv305 = R.call_tir(cls.dequantize1, (gpt_neox_layers_12_attention_query_key_value_q_weight4, gpt_neox_layers_12_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims243: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv305, axes=None) matmul243: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm123, permute_dims243, out_dtype="void") add360: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul243, gpt_neox_layers_12_attention_query_key_value_bias4) reshape240: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add360, R.shape([batch_size, 1, 24, 256])) reshape241: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape240, R.shape([batch_size, 24, 256])) lv306 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape241), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape242: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv306, R.shape([batch_size, 1, 8, 256])) reshape243: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape242, R.shape([batch_size, 1, 2048])) lv307 = R.call_tir(cls.dequantize2, (gpt_neox_layers_12_attention_dense_q_weight4, gpt_neox_layers_12_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims244: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv307, axes=None) matmul244: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape243, permute_dims244, out_dtype="void") add361: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul244, gpt_neox_layers_12_attention_dense_bias4) layer_norm124: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add359, gpt_neox_layers_12_post_attention_layernorm_weight4, gpt_neox_layers_12_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv308 = R.call_tir(cls.dequantize3, (gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims245: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv308, axes=None) matmul245: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm124, permute_dims245, out_dtype="float32") add362: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul245, gpt_neox_layers_12_mlp_dense_h_to_4h_bias4) gelu60: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add362) astype123: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu60, dtype="float16") lv309 = R.call_tir(cls.dequantize4, 
(gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims246: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv309, axes=None) matmul246: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype123, permute_dims246, out_dtype="float32") add363: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul246, gpt_neox_layers_12_mlp_dense_4h_to_h_bias4) astype124: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add363, dtype="float16") add364: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype124, add361) add365: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add364, add359) layer_norm125: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add365, gpt_neox_layers_13_input_layernorm_weight4, gpt_neox_layers_13_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv310 = R.call_tir(cls.dequantize1, (gpt_neox_layers_13_attention_query_key_value_q_weight4, gpt_neox_layers_13_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims247: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv310, axes=None) matmul247: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm125, permute_dims247, out_dtype="void") add366: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul247, gpt_neox_layers_13_attention_query_key_value_bias4) reshape244: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add366, R.shape([batch_size, 1, 24, 256])) reshape245: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape244, R.shape([batch_size, 24, 256])) lv311 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape245), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape246: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv311, R.shape([batch_size, 1, 8, 256])) reshape247: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape246, R.shape([batch_size, 1, 2048])) lv312 = R.call_tir(cls.dequantize2, (gpt_neox_layers_13_attention_dense_q_weight4, gpt_neox_layers_13_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims248: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv312, axes=None) matmul248: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape247, permute_dims248, out_dtype="void") add367: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul248, gpt_neox_layers_13_attention_dense_bias4) layer_norm126: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add365, gpt_neox_layers_13_post_attention_layernorm_weight4, gpt_neox_layers_13_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv313 = R.call_tir(cls.dequantize3, (gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims249: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv313, axes=None) matmul249: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm126, permute_dims249, out_dtype="float32") add368: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul249, gpt_neox_layers_13_mlp_dense_h_to_4h_bias4) gelu61: R.Tensor((batch_size, 1, 8192), 
dtype="float32") = R.nn.gelu(add368) astype125: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu61, dtype="float16") lv314 = R.call_tir(cls.dequantize4, (gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims250: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv314, axes=None) matmul250: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype125, permute_dims250, out_dtype="float32") add369: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul250, gpt_neox_layers_13_mlp_dense_4h_to_h_bias4) astype126: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add369, dtype="float16") add370: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype126, add367) add371: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add370, add365) layer_norm127: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add371, gpt_neox_layers_14_input_layernorm_weight4, gpt_neox_layers_14_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv315 = R.call_tir(cls.dequantize1, (gpt_neox_layers_14_attention_query_key_value_q_weight4, gpt_neox_layers_14_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims251: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv315, axes=None) matmul251: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm127, permute_dims251, out_dtype="void") add372: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul251, gpt_neox_layers_14_attention_query_key_value_bias4) reshape248: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add372, R.shape([batch_size, 1, 24, 256])) reshape249: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape248, R.shape([batch_size, 24, 256])) lv316 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape249), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape250: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv316, R.shape([batch_size, 1, 8, 256])) reshape251: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape250, R.shape([batch_size, 1, 2048])) lv317 = R.call_tir(cls.dequantize2, (gpt_neox_layers_14_attention_dense_q_weight4, gpt_neox_layers_14_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims252: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv317, axes=None) matmul252: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape251, permute_dims252, out_dtype="void") add373: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul252, gpt_neox_layers_14_attention_dense_bias4) layer_norm128: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add371, gpt_neox_layers_14_post_attention_layernorm_weight4, gpt_neox_layers_14_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv318 = R.call_tir(cls.dequantize3, (gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims253: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv318, axes=None) matmul253: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm128, permute_dims253, out_dtype="float32") 
add374: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul253, gpt_neox_layers_14_mlp_dense_h_to_4h_bias4) gelu62: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add374) astype127: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu62, dtype="float16") lv319 = R.call_tir(cls.dequantize4, (gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims254: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv319, axes=None) matmul254: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype127, permute_dims254, out_dtype="float32") add375: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul254, gpt_neox_layers_14_mlp_dense_4h_to_h_bias4) astype128: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add375, dtype="float16") add376: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype128, add373) add377: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add376, add371) layer_norm129: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add377, gpt_neox_layers_15_input_layernorm_weight4, gpt_neox_layers_15_input_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv320 = R.call_tir(cls.dequantize1, (gpt_neox_layers_15_attention_query_key_value_q_weight4, gpt_neox_layers_15_attention_query_key_value_q_scale4), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims255: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv320, axes=None) matmul255: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.matmul(layer_norm129, permute_dims255, out_dtype="void") add378: R.Tensor((batch_size, 1, 6144), dtype="float16") = R.add(matmul255, gpt_neox_layers_15_attention_query_key_value_bias4) reshape252: R.Tensor((batch_size, 1, 24, 256), dtype="float16") = R.reshape(add378, R.shape([batch_size, 1, 24, 256])) reshape253: R.Tensor((batch_size, 24, 256), dtype="float16") = R.reshape(reshape252, R.shape([batch_size, 24, 256])) lv321 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape253), out_sinfo=R.Tensor((batch_size, 8, 256), dtype="float16")) reshape254: R.Tensor((batch_size, 1, 8, 256), dtype="float16") = R.reshape(lv321, R.shape([batch_size, 1, 8, 256])) reshape255: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape254, R.shape([batch_size, 1, 2048])) lv322 = R.call_tir(cls.dequantize2, (gpt_neox_layers_15_attention_dense_q_weight4, gpt_neox_layers_15_attention_dense_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims256: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv322, axes=None) matmul256: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape255, permute_dims256, out_dtype="void") add379: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul256, gpt_neox_layers_15_attention_dense_bias4) layer_norm130: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add377, gpt_neox_layers_15_post_attention_layernorm_weight4, gpt_neox_layers_15_post_attention_layernorm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv323 = R.call_tir(cls.dequantize3, (gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight4, gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale4), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims257: R.Tensor((2048, 8192), 
dtype="float16") = R.permute_dims(lv323, axes=None) matmul257: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.matmul(layer_norm130, permute_dims257, out_dtype="float32") add380: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.add(matmul257, gpt_neox_layers_15_mlp_dense_h_to_4h_bias4) gelu63: R.Tensor((batch_size, 1, 8192), dtype="float32") = R.nn.gelu(add380) astype129: R.Tensor((batch_size, 1, 8192), dtype="float16") = R.astype(gelu63, dtype="float16") lv324 = R.call_tir(cls.dequantize4, (gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight4, gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale4), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims258: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv324, axes=None) matmul258: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.matmul(astype129, permute_dims258, out_dtype="float32") add381: R.Tensor((batch_size, 1, 2048), dtype="float32") = R.add(matmul258, gpt_neox_layers_15_mlp_dense_4h_to_h_bias4) astype130: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.astype(add381, dtype="float16") add382: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(astype130, add379) add383: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(add382, add377) layer_norm131: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.layer_norm(add383, gpt_neox_final_layer_norm_weight4, gpt_neox_final_layer_norm_bias4, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv325 = R.call_tir(cls.dequantize, (embed_out_q_weight4, embed_out_q_scale4), out_sinfo=R.Tensor((vocab_size, 2048), dtype="float16")) permute_dims259: R.Tensor((2048, vocab_size), dtype="float16") = R.permute_dims(lv325, axes=None) matmul259: R.Tensor((batch_size, 1, vocab_size), dtype="float16") = R.matmul(layer_norm131, permute_dims259, out_dtype="void") astype131: R.Tensor((batch_size, 1, vocab_size), dtype="float32") = R.astype(matmul259, dtype="float32") gv4: R.Tuple(R.Tensor((batch_size, 1, vocab_size), dtype="float32"), R.Object) = astype131, paged_kv_cache R.output(gv4) return gv4 @R.function def batch_prefill(input_embeds: R.Tensor((1, "seq_len", 2048), dtype="float16"), logit_positions: R.Tensor(("batch_size",), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), 
dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), 
R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), 
dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"))) -> R.Tuple(R.Tensor((1, "batch_size", "vocab_size"), dtype="float32"), R.Object): batch_size = T.int64() vocab_size = T.int64() seq_len = T.int64() R.func_attr({"num_input": 3, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module with R.dataflow(): gpt_neox_embed_in_q_weight3: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[0] gpt_neox_embed_in_q_scale3: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[1] gpt_neox_layers_0_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[2] gpt_neox_layers_0_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[3] gpt_neox_layers_0_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[4] gpt_neox_layers_0_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[5] gpt_neox_layers_0_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[6] gpt_neox_layers_0_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[7] gpt_neox_layers_0_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[8] gpt_neox_layers_0_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[9] gpt_neox_layers_0_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[10] gpt_neox_layers_0_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[11] gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[12] gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[13] gpt_neox_layers_0_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[14] gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[15] 
gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[16] gpt_neox_layers_0_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[17] gpt_neox_layers_1_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[18] gpt_neox_layers_1_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[19] gpt_neox_layers_1_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[20] gpt_neox_layers_1_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[21] gpt_neox_layers_1_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[22] gpt_neox_layers_1_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[23] gpt_neox_layers_1_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[24] gpt_neox_layers_1_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[25] gpt_neox_layers_1_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[26] gpt_neox_layers_1_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[27] gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[28] gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[29] gpt_neox_layers_1_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[30] gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[31] gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[32] gpt_neox_layers_1_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[33] gpt_neox_layers_2_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[34] gpt_neox_layers_2_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[35] gpt_neox_layers_2_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[36] gpt_neox_layers_2_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[37] gpt_neox_layers_2_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[38] gpt_neox_layers_2_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[39] gpt_neox_layers_2_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[40] gpt_neox_layers_2_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[41] gpt_neox_layers_2_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[42] gpt_neox_layers_2_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[43] gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[44] gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[45] gpt_neox_layers_2_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[46] gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[47] gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[48] gpt_neox_layers_2_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[49] gpt_neox_layers_3_input_layernorm_weight3: 
R.Tensor((2048,), dtype="float16") = packed_params[50] gpt_neox_layers_3_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[51] gpt_neox_layers_3_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[52] gpt_neox_layers_3_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[53] gpt_neox_layers_3_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[54] gpt_neox_layers_3_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[55] gpt_neox_layers_3_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[56] gpt_neox_layers_3_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[57] gpt_neox_layers_3_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[58] gpt_neox_layers_3_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[59] gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[60] gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[61] gpt_neox_layers_3_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[62] gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[63] gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[64] gpt_neox_layers_3_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[65] gpt_neox_layers_4_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[66] gpt_neox_layers_4_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[67] gpt_neox_layers_4_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[68] gpt_neox_layers_4_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[69] gpt_neox_layers_4_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[70] gpt_neox_layers_4_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[71] gpt_neox_layers_4_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[72] gpt_neox_layers_4_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[73] gpt_neox_layers_4_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[74] gpt_neox_layers_4_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[75] gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[76] gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[77] gpt_neox_layers_4_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[78] gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[79] gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[80] gpt_neox_layers_4_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[81] gpt_neox_layers_5_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[82] gpt_neox_layers_5_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[83] gpt_neox_layers_5_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = 
packed_params[84] gpt_neox_layers_5_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[85] gpt_neox_layers_5_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[86] gpt_neox_layers_5_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[87] gpt_neox_layers_5_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[88] gpt_neox_layers_5_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[89] gpt_neox_layers_5_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[90] gpt_neox_layers_5_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[91] gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[92] gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[93] gpt_neox_layers_5_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[94] gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[95] gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[96] gpt_neox_layers_5_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[97] gpt_neox_layers_6_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[98] gpt_neox_layers_6_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[99] gpt_neox_layers_6_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[100] gpt_neox_layers_6_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[101] gpt_neox_layers_6_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[102] gpt_neox_layers_6_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[103] gpt_neox_layers_6_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[104] gpt_neox_layers_6_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[105] gpt_neox_layers_6_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[106] gpt_neox_layers_6_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[107] gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[108] gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[109] gpt_neox_layers_6_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[110] gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[111] gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[112] gpt_neox_layers_6_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[113] gpt_neox_layers_7_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[114] gpt_neox_layers_7_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[115] gpt_neox_layers_7_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[116] gpt_neox_layers_7_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[117] gpt_neox_layers_7_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = 
gpt_neox_layers_7_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[119]
gpt_neox_layers_7_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[120]
gpt_neox_layers_7_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[121]
gpt_neox_layers_7_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[122]
gpt_neox_layers_7_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[123]
gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[124]
gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[125]
gpt_neox_layers_7_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[126]
gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[127]
gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[128]
gpt_neox_layers_7_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[129]
gpt_neox_layers_8_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[130]
gpt_neox_layers_8_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[131]
gpt_neox_layers_8_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[132]
gpt_neox_layers_8_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[133]
gpt_neox_layers_8_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[134]
gpt_neox_layers_8_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[135]
gpt_neox_layers_8_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[136]
gpt_neox_layers_8_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
gpt_neox_layers_8_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
gpt_neox_layers_8_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[139]
gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[140]
gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[141]
gpt_neox_layers_8_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[142]
gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[143]
gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[144]
gpt_neox_layers_8_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[145]
gpt_neox_layers_9_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[146]
gpt_neox_layers_9_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[147]
gpt_neox_layers_9_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[148]
gpt_neox_layers_9_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[149]
gpt_neox_layers_9_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[150]
gpt_neox_layers_9_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[151]
gpt_neox_layers_9_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[152]
dtype="float16") = packed_params[152] gpt_neox_layers_9_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[153] gpt_neox_layers_9_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[154] gpt_neox_layers_9_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[155] gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[156] gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[157] gpt_neox_layers_9_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[158] gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[159] gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[160] gpt_neox_layers_9_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[161] gpt_neox_layers_10_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[162] gpt_neox_layers_10_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[163] gpt_neox_layers_10_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[164] gpt_neox_layers_10_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[165] gpt_neox_layers_10_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[166] gpt_neox_layers_10_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[167] gpt_neox_layers_10_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[168] gpt_neox_layers_10_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[169] gpt_neox_layers_10_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[170] gpt_neox_layers_10_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[171] gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[172] gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[173] gpt_neox_layers_10_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[174] gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[175] gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[176] gpt_neox_layers_10_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[177] gpt_neox_layers_11_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[178] gpt_neox_layers_11_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[179] gpt_neox_layers_11_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[180] gpt_neox_layers_11_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[181] gpt_neox_layers_11_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[182] gpt_neox_layers_11_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[183] gpt_neox_layers_11_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[184] gpt_neox_layers_11_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[185] gpt_neox_layers_11_attention_dense_q_scale3: 
gpt_neox_layers_11_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[187]
gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[188]
gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[189]
gpt_neox_layers_11_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[190]
gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[191]
gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[192]
gpt_neox_layers_11_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[193]
gpt_neox_layers_12_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[194]
gpt_neox_layers_12_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[195]
gpt_neox_layers_12_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[196]
gpt_neox_layers_12_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[197]
gpt_neox_layers_12_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[198]
gpt_neox_layers_12_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[199]
gpt_neox_layers_12_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[200]
gpt_neox_layers_12_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[201]
gpt_neox_layers_12_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[202]
gpt_neox_layers_12_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[203]
gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[204]
gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[205]
gpt_neox_layers_12_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[206]
gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[207]
gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[208]
gpt_neox_layers_12_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[209]
gpt_neox_layers_13_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[210]
gpt_neox_layers_13_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[211]
gpt_neox_layers_13_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[212]
gpt_neox_layers_13_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[213]
gpt_neox_layers_13_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[214]
gpt_neox_layers_13_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[215]
gpt_neox_layers_13_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[216]
gpt_neox_layers_13_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[217]
gpt_neox_layers_13_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[218]
gpt_neox_layers_13_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[219]
gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[220]
gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[221]
gpt_neox_layers_13_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[222]
gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[223]
gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[224]
gpt_neox_layers_13_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[225]
gpt_neox_layers_14_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[226]
gpt_neox_layers_14_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[227]
gpt_neox_layers_14_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[228]
gpt_neox_layers_14_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[229]
gpt_neox_layers_14_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[230]
gpt_neox_layers_14_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[231]
gpt_neox_layers_14_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[232]
gpt_neox_layers_14_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[233]
gpt_neox_layers_14_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[234]
gpt_neox_layers_14_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[235]
gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[236]
gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[237]
gpt_neox_layers_14_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[238]
gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[239]
gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[240]
gpt_neox_layers_14_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[241]
gpt_neox_layers_15_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[242]
gpt_neox_layers_15_input_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[243]
gpt_neox_layers_15_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[244]
gpt_neox_layers_15_post_attention_layernorm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[245]
gpt_neox_layers_15_attention_query_key_value_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[246]
gpt_neox_layers_15_attention_query_key_value_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[247]
gpt_neox_layers_15_attention_query_key_value_bias3: R.Tensor((6144,), dtype="float16") = packed_params[248]
gpt_neox_layers_15_attention_dense_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[249]
gpt_neox_layers_15_attention_dense_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[250]
gpt_neox_layers_15_attention_dense_bias3: R.Tensor((2048,), dtype="float16") = packed_params[251]
gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight3: R.Tensor((8192, 256), dtype="uint32") = packed_params[252]
gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale3: R.Tensor((8192, 64), dtype="float16") = packed_params[253]
dtype="float16") = packed_params[253] gpt_neox_layers_15_mlp_dense_h_to_4h_bias3: R.Tensor((8192,), dtype="float32") = packed_params[254] gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight3: R.Tensor((2048, 1024), dtype="uint32") = packed_params[255] gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale3: R.Tensor((2048, 256), dtype="float16") = packed_params[256] gpt_neox_layers_15_mlp_dense_4h_to_h_bias3: R.Tensor((2048,), dtype="float32") = packed_params[257] gpt_neox_final_layer_norm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[258] gpt_neox_final_layer_norm_bias3: R.Tensor((2048,), dtype="float16") = packed_params[259] embed_out_q_weight3: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[260] embed_out_q_scale3: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[261] layer_norm66: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(input_embeds, gpt_neox_layers_0_input_layernorm_weight3, gpt_neox_layers_0_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv164 = R.call_tir(cls.dequantize1, (gpt_neox_layers_0_attention_query_key_value_q_weight3, gpt_neox_layers_0_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims130: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv164, axes=None) matmul130: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm66, permute_dims130, out_dtype="void") add192: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul130, gpt_neox_layers_0_attention_query_key_value_bias3) reshape128: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add192, R.shape([1, seq_len, 24, 256])) reshape129: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape128, R.shape([seq_len, 24, 256])) lv165 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape129), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape130: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv165, R.shape([1, seq_len, 8, 256])) reshape131: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape130, R.shape([1, seq_len, 2048])) lv166 = R.call_tir(cls.dequantize2, (gpt_neox_layers_0_attention_dense_q_weight3, gpt_neox_layers_0_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims131: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv166, axes=None) matmul131: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape131, permute_dims131, out_dtype="void") add193: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul131, gpt_neox_layers_0_attention_dense_bias3) layer_norm67: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(input_embeds, gpt_neox_layers_0_post_attention_layernorm_weight3, gpt_neox_layers_0_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv167 = R.call_tir(cls.dequantize3, (gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims132: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv167, axes=None) matmul132: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm67, permute_dims132, out_dtype="float32") add194: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul132, gpt_neox_layers_0_mlp_dense_h_to_4h_bias3) gelu32: R.Tensor((1, seq_len, 8192), 
dtype="float32") = R.nn.gelu(add194) astype66: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu32, dtype="float16") lv168 = R.call_tir(cls.dequantize4, (gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims133: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv168, axes=None) matmul133: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype66, permute_dims133, out_dtype="float32") add195: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul133, gpt_neox_layers_0_mlp_dense_4h_to_h_bias3) astype67: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add195, dtype="float16") add196: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype67, add193) add197: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add196, input_embeds) layer_norm68: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add197, gpt_neox_layers_1_input_layernorm_weight3, gpt_neox_layers_1_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv169 = R.call_tir(cls.dequantize1, (gpt_neox_layers_1_attention_query_key_value_q_weight3, gpt_neox_layers_1_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims134: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv169, axes=None) matmul134: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm68, permute_dims134, out_dtype="void") add198: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul134, gpt_neox_layers_1_attention_query_key_value_bias3) reshape132: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add198, R.shape([1, seq_len, 24, 256])) reshape133: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape132, R.shape([seq_len, 24, 256])) lv170 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape133), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape134: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv170, R.shape([1, seq_len, 8, 256])) reshape135: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape134, R.shape([1, seq_len, 2048])) lv171 = R.call_tir(cls.dequantize2, (gpt_neox_layers_1_attention_dense_q_weight3, gpt_neox_layers_1_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims135: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv171, axes=None) matmul135: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape135, permute_dims135, out_dtype="void") add199: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul135, gpt_neox_layers_1_attention_dense_bias3) layer_norm69: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add197, gpt_neox_layers_1_post_attention_layernorm_weight3, gpt_neox_layers_1_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv172 = R.call_tir(cls.dequantize3, (gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims136: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv172, axes=None) matmul136: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm69, permute_dims136, out_dtype="float32") add200: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul136, 
gelu33: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add200)
astype68: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu33, dtype="float16")
lv173 = R.call_tir(cls.dequantize4, (gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
permute_dims137: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv173, axes=None)
matmul137: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype68, permute_dims137, out_dtype="float32")
add201: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul137, gpt_neox_layers_1_mlp_dense_4h_to_h_bias3)
astype69: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add201, dtype="float16")
add202: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype69, add199)
add203: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add202, add197)
layer_norm70: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add203, gpt_neox_layers_2_input_layernorm_weight3, gpt_neox_layers_2_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv174 = R.call_tir(cls.dequantize1, (gpt_neox_layers_2_attention_query_key_value_q_weight3, gpt_neox_layers_2_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
permute_dims138: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv174, axes=None)
matmul138: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm70, permute_dims138, out_dtype="void")
add204: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul138, gpt_neox_layers_2_attention_query_key_value_bias3)
reshape136: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add204, R.shape([1, seq_len, 24, 256]))
reshape137: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape136, R.shape([seq_len, 24, 256]))
lv175 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape137), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
reshape138: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv175, R.shape([1, seq_len, 8, 256]))
reshape139: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape138, R.shape([1, seq_len, 2048]))
lv176 = R.call_tir(cls.dequantize2, (gpt_neox_layers_2_attention_dense_q_weight3, gpt_neox_layers_2_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
permute_dims139: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv176, axes=None)
matmul139: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape139, permute_dims139, out_dtype="void")
add205: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul139, gpt_neox_layers_2_attention_dense_bias3)
layer_norm71: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add203, gpt_neox_layers_2_post_attention_layernorm_weight3, gpt_neox_layers_2_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv177 = R.call_tir(cls.dequantize3, (gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16"))
permute_dims140: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv177, axes=None)
matmul140: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm71, permute_dims140, out_dtype="float32")
add206: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul140, gpt_neox_layers_2_mlp_dense_h_to_4h_bias3)
gelu34: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add206)
astype70: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu34, dtype="float16")
lv178 = R.call_tir(cls.dequantize4, (gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
permute_dims141: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv178, axes=None)
matmul141: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype70, permute_dims141, out_dtype="float32")
add207: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul141, gpt_neox_layers_2_mlp_dense_4h_to_h_bias3)
astype71: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add207, dtype="float16")
add208: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype71, add205)
add209: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add208, add203)
layer_norm72: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add209, gpt_neox_layers_3_input_layernorm_weight3, gpt_neox_layers_3_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv179 = R.call_tir(cls.dequantize1, (gpt_neox_layers_3_attention_query_key_value_q_weight3, gpt_neox_layers_3_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
permute_dims142: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv179, axes=None)
matmul142: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm72, permute_dims142, out_dtype="void")
add210: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul142, gpt_neox_layers_3_attention_query_key_value_bias3)
reshape140: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add210, R.shape([1, seq_len, 24, 256]))
reshape141: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape140, R.shape([seq_len, 24, 256]))
lv180 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape141), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
reshape142: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv180, R.shape([1, seq_len, 8, 256]))
reshape143: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape142, R.shape([1, seq_len, 2048]))
lv181 = R.call_tir(cls.dequantize2, (gpt_neox_layers_3_attention_dense_q_weight3, gpt_neox_layers_3_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
permute_dims143: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv181, axes=None)
matmul143: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape143, permute_dims143, out_dtype="void")
add211: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul143, gpt_neox_layers_3_attention_dense_bias3)
layer_norm73: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add209, gpt_neox_layers_3_post_attention_layernorm_weight3, gpt_neox_layers_3_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv182 = R.call_tir(cls.dequantize3, (gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16"))
permute_dims144: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv182, axes=None)
matmul144: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm73, permute_dims144, out_dtype="float32")
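# [Editorial note] Attention is delegated to the paged-KV-cache runtime builtin:
# the fused QKV activation is reshaped to (seq_len, 24, 256), which matches
# 8 query + 8 key + 8 value heads with head_dim 256, and
# vm.builtin.attention_kv_cache_attention_with_fused_qkv(paged_kv_cache,
# layer_idx, attn_score_scaling_factor, qkv) appends K/V for that layer and
# returns the (seq_len, 8, 256) attention output. The first R.prim_value is the
# layer index, incrementing 0..15 across the blocks; the scaling factor is 1.0 here.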
add212: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul144, gpt_neox_layers_3_mlp_dense_h_to_4h_bias3)
gelu35: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add212)
astype72: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu35, dtype="float16")
lv183 = R.call_tir(cls.dequantize4, (gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
permute_dims145: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv183, axes=None)
matmul145: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype72, permute_dims145, out_dtype="float32")
add213: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul145, gpt_neox_layers_3_mlp_dense_4h_to_h_bias3)
astype73: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add213, dtype="float16")
add214: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype73, add211)
add215: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add214, add209)
layer_norm74: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add215, gpt_neox_layers_4_input_layernorm_weight3, gpt_neox_layers_4_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv184 = R.call_tir(cls.dequantize1, (gpt_neox_layers_4_attention_query_key_value_q_weight3, gpt_neox_layers_4_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
permute_dims146: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv184, axes=None)
matmul146: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm74, permute_dims146, out_dtype="void")
add216: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul146, gpt_neox_layers_4_attention_query_key_value_bias3)
reshape144: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add216, R.shape([1, seq_len, 24, 256]))
reshape145: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape144, R.shape([seq_len, 24, 256]))
lv185 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape145), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
reshape146: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv185, R.shape([1, seq_len, 8, 256]))
reshape147: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape146, R.shape([1, seq_len, 2048]))
lv186 = R.call_tir(cls.dequantize2, (gpt_neox_layers_4_attention_dense_q_weight3, gpt_neox_layers_4_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
permute_dims147: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv186, axes=None)
matmul147: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape147, permute_dims147, out_dtype="void")
add217: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul147, gpt_neox_layers_4_attention_dense_bias3)
layer_norm75: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add215, gpt_neox_layers_4_post_attention_layernorm_weight3, gpt_neox_layers_4_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv187 = R.call_tir(cls.dequantize3, (gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16"))
permute_dims148: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv187, axes=None)
matmul148: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm75, permute_dims148, out_dtype="float32")
add218: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul148, gpt_neox_layers_4_mlp_dense_h_to_4h_bias3)
gelu36: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add218)
astype74: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu36, dtype="float16")
lv188 = R.call_tir(cls.dequantize4, (gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
permute_dims149: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv188, axes=None)
matmul149: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype74, permute_dims149, out_dtype="float32")
add219: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul149, gpt_neox_layers_4_mlp_dense_4h_to_h_bias3)
astype75: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add219, dtype="float16")
add220: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype75, add217)
add221: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add220, add215)
layer_norm76: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add221, gpt_neox_layers_5_input_layernorm_weight3, gpt_neox_layers_5_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv189 = R.call_tir(cls.dequantize1, (gpt_neox_layers_5_attention_query_key_value_q_weight3, gpt_neox_layers_5_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
permute_dims150: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv189, axes=None)
matmul150: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm76, permute_dims150, out_dtype="void")
add222: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul150, gpt_neox_layers_5_attention_query_key_value_bias3)
reshape148: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add222, R.shape([1, seq_len, 24, 256]))
reshape149: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape148, R.shape([seq_len, 24, 256]))
lv190 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape149), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
reshape150: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv190, R.shape([1, seq_len, 8, 256]))
reshape151: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape150, R.shape([1, seq_len, 2048]))
lv191 = R.call_tir(cls.dequantize2, (gpt_neox_layers_5_attention_dense_q_weight3, gpt_neox_layers_5_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
permute_dims151: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv191, axes=None)
matmul151: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape151, permute_dims151, out_dtype="void")
add223: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul151, gpt_neox_layers_5_attention_dense_bias3)
layer_norm77: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add221, gpt_neox_layers_5_post_attention_layernorm_weight3, gpt_neox_layers_5_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv192 = R.call_tir(cls.dequantize3, (gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16"))
permute_dims152: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv192, axes=None)
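# [Editorial note] MLP numerics: both MLP matmuls accumulate in float32
# (out_dtype="float32"), the float32 biases and the gelu are applied at that
# precision, and the result is cast back to float16 before rejoining the fp16
# residual stream. The attention-side matmuls keep float16 (out_dtype="void"
# leaves the output dtype equal to the input dtype).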
matmul152: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm77, permute_dims152, out_dtype="float32")
add224: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul152, gpt_neox_layers_5_mlp_dense_h_to_4h_bias3)
gelu37: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add224)
astype76: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu37, dtype="float16")
lv193 = R.call_tir(cls.dequantize4, (gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
permute_dims153: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv193, axes=None)
matmul153: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype76, permute_dims153, out_dtype="float32")
add225: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul153, gpt_neox_layers_5_mlp_dense_4h_to_h_bias3)
astype77: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add225, dtype="float16")
add226: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype77, add223)
add227: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add226, add221)
layer_norm78: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add227, gpt_neox_layers_6_input_layernorm_weight3, gpt_neox_layers_6_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv194 = R.call_tir(cls.dequantize1, (gpt_neox_layers_6_attention_query_key_value_q_weight3, gpt_neox_layers_6_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
permute_dims154: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv194, axes=None)
matmul154: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm78, permute_dims154, out_dtype="void")
add228: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul154, gpt_neox_layers_6_attention_query_key_value_bias3)
reshape152: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add228, R.shape([1, seq_len, 24, 256]))
reshape153: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape152, R.shape([seq_len, 24, 256]))
lv195 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape153), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
reshape154: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv195, R.shape([1, seq_len, 8, 256]))
reshape155: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape154, R.shape([1, seq_len, 2048]))
lv196 = R.call_tir(cls.dequantize2, (gpt_neox_layers_6_attention_dense_q_weight3, gpt_neox_layers_6_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
permute_dims155: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv196, axes=None)
matmul155: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape155, permute_dims155, out_dtype="void")
add229: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul155, gpt_neox_layers_6_attention_dense_bias3)
layer_norm79: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add227, gpt_neox_layers_6_post_attention_layernorm_weight3, gpt_neox_layers_6_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv197 = R.call_tir(cls.dequantize3, (gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16"))
permute_dims156: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv197, axes=None)
matmul156: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm79, permute_dims156, out_dtype="float32")
add230: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul156, gpt_neox_layers_6_mlp_dense_h_to_4h_bias3)
gelu38: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add230)
astype78: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu38, dtype="float16")
lv198 = R.call_tir(cls.dequantize4, (gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
permute_dims157: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv198, axes=None)
matmul157: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype78, permute_dims157, out_dtype="float32")
add231: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul157, gpt_neox_layers_6_mlp_dense_4h_to_h_bias3)
astype79: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add231, dtype="float16")
add232: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype79, add229)
add233: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add232, add227)
layer_norm80: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add233, gpt_neox_layers_7_input_layernorm_weight3, gpt_neox_layers_7_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv199 = R.call_tir(cls.dequantize1, (gpt_neox_layers_7_attention_query_key_value_q_weight3, gpt_neox_layers_7_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
permute_dims158: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv199, axes=None)
matmul158: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm80, permute_dims158, out_dtype="void")
add234: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul158, gpt_neox_layers_7_attention_query_key_value_bias3)
reshape156: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add234, R.shape([1, seq_len, 24, 256]))
reshape157: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape156, R.shape([seq_len, 24, 256]))
lv200 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape157), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
reshape158: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv200, R.shape([1, seq_len, 8, 256]))
reshape159: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape158, R.shape([1, seq_len, 2048]))
lv201 = R.call_tir(cls.dequantize2, (gpt_neox_layers_7_attention_dense_q_weight3, gpt_neox_layers_7_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
permute_dims159: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv201, axes=None)
matmul159: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape159, permute_dims159, out_dtype="void")
add235: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul159, gpt_neox_layers_7_attention_dense_bias3)
layer_norm81: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add233, gpt_neox_layers_7_post_attention_layernorm_weight3, gpt_neox_layers_7_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv202 = R.call_tir(cls.dequantize3, (gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16"))
permute_dims160: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv202, axes=None)
matmul160: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm81, permute_dims160, out_dtype="float32")
add236: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul160, gpt_neox_layers_7_mlp_dense_h_to_4h_bias3)
gelu39: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add236)
astype80: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu39, dtype="float16")
lv203 = R.call_tir(cls.dequantize4, (gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
permute_dims161: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv203, axes=None)
matmul161: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype80, permute_dims161, out_dtype="float32")
add237: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul161, gpt_neox_layers_7_mlp_dense_4h_to_h_bias3)
astype81: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add237, dtype="float16")
add238: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype81, add235)
add239: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add238, add233)
layer_norm82: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add239, gpt_neox_layers_8_input_layernorm_weight3, gpt_neox_layers_8_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv204 = R.call_tir(cls.dequantize1, (gpt_neox_layers_8_attention_query_key_value_q_weight3, gpt_neox_layers_8_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
permute_dims162: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv204, axes=None)
matmul162: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm82, permute_dims162, out_dtype="void")
add240: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul162, gpt_neox_layers_8_attention_query_key_value_bias3)
reshape160: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add240, R.shape([1, seq_len, 24, 256]))
reshape161: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape160, R.shape([seq_len, 24, 256]))
lv205 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape161), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
reshape162: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv205, R.shape([1, seq_len, 8, 256]))
reshape163: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape162, R.shape([1, seq_len, 2048]))
lv206 = R.call_tir(cls.dequantize2, (gpt_neox_layers_8_attention_dense_q_weight3, gpt_neox_layers_8_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
permute_dims163: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv206, axes=None)
matmul163: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape163, permute_dims163, out_dtype="void")
add241: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul163, gpt_neox_layers_8_attention_dense_bias3)
layer_norm83: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add239, gpt_neox_layers_8_post_attention_layernorm_weight3, gpt_neox_layers_8_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv207 = R.call_tir(cls.dequantize3, (gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16"))
permute_dims164: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv207, axes=None)
matmul164: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm83, permute_dims164, out_dtype="float32")
add242: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul164, gpt_neox_layers_8_mlp_dense_h_to_4h_bias3)
gelu40: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add242)
astype82: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu40, dtype="float16")
lv208 = R.call_tir(cls.dequantize4, (gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
permute_dims165: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv208, axes=None)
matmul165: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype82, permute_dims165, out_dtype="float32")
add243: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul165, gpt_neox_layers_8_mlp_dense_4h_to_h_bias3)
astype83: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add243, dtype="float16")
add244: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype83, add241)
add245: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add244, add239)
layer_norm84: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add245, gpt_neox_layers_9_input_layernorm_weight3, gpt_neox_layers_9_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv209 = R.call_tir(cls.dequantize1, (gpt_neox_layers_9_attention_query_key_value_q_weight3, gpt_neox_layers_9_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
permute_dims166: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv209, axes=None)
matmul166: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm84, permute_dims166, out_dtype="void")
add246: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul166, gpt_neox_layers_9_attention_query_key_value_bias3)
reshape164: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add246, R.shape([1, seq_len, 24, 256]))
reshape165: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape164, R.shape([seq_len, 24, 256]))
lv210 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape165), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
reshape166: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv210, R.shape([1, seq_len, 8, 256]))
reshape167: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape166, R.shape([1, seq_len, 2048]))
lv211 = R.call_tir(cls.dequantize2, (gpt_neox_layers_9_attention_dense_q_weight3, gpt_neox_layers_9_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
permute_dims167: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv211, axes=None)
matmul167: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape167, permute_dims167, out_dtype="void")
add247: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul167, gpt_neox_layers_9_attention_dense_bias3)
layer_norm85: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add245, gpt_neox_layers_9_post_attention_layernorm_weight3, gpt_neox_layers_9_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv212 = R.call_tir(cls.dequantize3, (gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16"))
permute_dims168: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv212, axes=None)
matmul168: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm85, permute_dims168, out_dtype="float32")
add248: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul168, gpt_neox_layers_9_mlp_dense_h_to_4h_bias3)
gelu41: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add248)
astype84: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu41, dtype="float16")
lv213 = R.call_tir(cls.dequantize4, (gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
permute_dims169: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv213, axes=None)
matmul169: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype84, permute_dims169, out_dtype="float32")
add249: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul169, gpt_neox_layers_9_mlp_dense_4h_to_h_bias3)
astype85: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add249, dtype="float16")
add250: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype85, add247)
add251: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add250, add245)
layer_norm86: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add251, gpt_neox_layers_10_input_layernorm_weight3, gpt_neox_layers_10_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv214 = R.call_tir(cls.dequantize1, (gpt_neox_layers_10_attention_query_key_value_q_weight3, gpt_neox_layers_10_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
permute_dims170: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv214, axes=None)
matmul170: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm86, permute_dims170, out_dtype="void")
add252: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul170, gpt_neox_layers_10_attention_query_key_value_bias3)
reshape168: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add252, R.shape([1, seq_len, 24, 256]))
reshape169: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape168, R.shape([seq_len, 24, 256]))
lv215 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape169), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
reshape170: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv215, R.shape([1, seq_len, 8, 256]))
reshape171: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape170, R.shape([1, seq_len, 2048]))
lv216 = R.call_tir(cls.dequantize2, (gpt_neox_layers_10_attention_dense_q_weight3, gpt_neox_layers_10_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
permute_dims171: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv216, axes=None)
matmul171: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape171, permute_dims171, out_dtype="void")
add253: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul171, gpt_neox_layers_10_attention_dense_bias3)
layer_norm87: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add251, gpt_neox_layers_10_post_attention_layernorm_weight3, gpt_neox_layers_10_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv217 = R.call_tir(cls.dequantize3, (gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16"))
permute_dims172: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv217, axes=None)
matmul172: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm87, permute_dims172, out_dtype="float32")
add254: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul172, gpt_neox_layers_10_mlp_dense_h_to_4h_bias3)
gelu42: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add254)
astype86: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu42, dtype="float16")
lv218 = R.call_tir(cls.dequantize4, (gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
permute_dims173: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv218, axes=None)
matmul173: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype86, permute_dims173, out_dtype="float32")
add255: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul173, gpt_neox_layers_10_mlp_dense_4h_to_h_bias3)
astype87: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add255, dtype="float16")
add256: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype87, add253)
add257: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add256, add251)
layer_norm88: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add257, gpt_neox_layers_11_input_layernorm_weight3, gpt_neox_layers_11_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv219 = R.call_tir(cls.dequantize1, (gpt_neox_layers_11_attention_query_key_value_q_weight3, gpt_neox_layers_11_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
permute_dims174: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv219, axes=None)
matmul174: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm88, permute_dims174, out_dtype="void")
add258: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul174, gpt_neox_layers_11_attention_query_key_value_bias3)
reshape172: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add258, R.shape([1, seq_len, 24, 256]))
reshape173: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape172, R.shape([seq_len, 24, 256]))
lv220 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape173), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
reshape174: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv220, R.shape([1, seq_len, 8, 256]))
reshape175: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape174, R.shape([1, seq_len, 2048]))
lv221 = R.call_tir(cls.dequantize2, (gpt_neox_layers_11_attention_dense_q_weight3, gpt_neox_layers_11_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
permute_dims175: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv221, axes=None)
matmul175: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape175, permute_dims175, out_dtype="void")
add259: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul175, gpt_neox_layers_11_attention_dense_bias3)
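# [Editorial note] The remaining decoder layers repeat this exact pattern,
# changing only the per-layer parameter names, the KV-cache layer index passed
# through R.prim_value, and the SSA value numbering; the parameter list above
# runs through gpt_neox_layers_15, so layer 15 presumably closes the stack
# before gpt_neox_final_layer_norm and the embed_out projection.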
dtype="float16") = R.add(matmul175, gpt_neox_layers_11_attention_dense_bias3) layer_norm89: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add257, gpt_neox_layers_11_post_attention_layernorm_weight3, gpt_neox_layers_11_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv222 = R.call_tir(cls.dequantize3, (gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims176: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv222, axes=None) matmul176: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm89, permute_dims176, out_dtype="float32") add260: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul176, gpt_neox_layers_11_mlp_dense_h_to_4h_bias3) gelu43: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add260) astype88: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu43, dtype="float16") lv223 = R.call_tir(cls.dequantize4, (gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims177: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv223, axes=None) matmul177: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype88, permute_dims177, out_dtype="float32") add261: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul177, gpt_neox_layers_11_mlp_dense_4h_to_h_bias3) astype89: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add261, dtype="float16") add262: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype89, add259) add263: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add262, add257) layer_norm90: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add263, gpt_neox_layers_12_input_layernorm_weight3, gpt_neox_layers_12_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv224 = R.call_tir(cls.dequantize1, (gpt_neox_layers_12_attention_query_key_value_q_weight3, gpt_neox_layers_12_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims178: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv224, axes=None) matmul178: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm90, permute_dims178, out_dtype="void") add264: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul178, gpt_neox_layers_12_attention_query_key_value_bias3) reshape176: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add264, R.shape([1, seq_len, 24, 256])) reshape177: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape176, R.shape([seq_len, 24, 256])) lv225 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape177), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape178: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv225, R.shape([1, seq_len, 8, 256])) reshape179: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape178, R.shape([1, seq_len, 2048])) lv226 = R.call_tir(cls.dequantize2, (gpt_neox_layers_12_attention_dense_q_weight3, gpt_neox_layers_12_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims179: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv226, axes=None) matmul179: R.Tensor((1, seq_len, 2048), dtype="float16") = 
R.matmul(reshape179, permute_dims179, out_dtype="void") add265: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul179, gpt_neox_layers_12_attention_dense_bias3) layer_norm91: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add263, gpt_neox_layers_12_post_attention_layernorm_weight3, gpt_neox_layers_12_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv227 = R.call_tir(cls.dequantize3, (gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims180: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv227, axes=None) matmul180: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm91, permute_dims180, out_dtype="float32") add266: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul180, gpt_neox_layers_12_mlp_dense_h_to_4h_bias3) gelu44: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add266) astype90: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu44, dtype="float16") lv228 = R.call_tir(cls.dequantize4, (gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims181: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv228, axes=None) matmul181: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype90, permute_dims181, out_dtype="float32") add267: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul181, gpt_neox_layers_12_mlp_dense_4h_to_h_bias3) astype91: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add267, dtype="float16") add268: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype91, add265) add269: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add268, add263) layer_norm92: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add269, gpt_neox_layers_13_input_layernorm_weight3, gpt_neox_layers_13_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv229 = R.call_tir(cls.dequantize1, (gpt_neox_layers_13_attention_query_key_value_q_weight3, gpt_neox_layers_13_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims182: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv229, axes=None) matmul182: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm92, permute_dims182, out_dtype="void") add270: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul182, gpt_neox_layers_13_attention_query_key_value_bias3) reshape180: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add270, R.shape([1, seq_len, 24, 256])) reshape181: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape180, R.shape([seq_len, 24, 256])) lv230 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape181), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape182: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv230, R.shape([1, seq_len, 8, 256])) reshape183: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape182, R.shape([1, seq_len, 2048])) lv231 = R.call_tir(cls.dequantize2, (gpt_neox_layers_13_attention_dense_q_weight3, gpt_neox_layers_13_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims183: R.Tensor((2048, 2048), dtype="float16") = 
R.permute_dims(lv231, axes=None) matmul183: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape183, permute_dims183, out_dtype="void") add271: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul183, gpt_neox_layers_13_attention_dense_bias3) layer_norm93: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add269, gpt_neox_layers_13_post_attention_layernorm_weight3, gpt_neox_layers_13_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv232 = R.call_tir(cls.dequantize3, (gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims184: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv232, axes=None) matmul184: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm93, permute_dims184, out_dtype="float32") add272: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul184, gpt_neox_layers_13_mlp_dense_h_to_4h_bias3) gelu45: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add272) astype92: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu45, dtype="float16") lv233 = R.call_tir(cls.dequantize4, (gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims185: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv233, axes=None) matmul185: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype92, permute_dims185, out_dtype="float32") add273: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul185, gpt_neox_layers_13_mlp_dense_4h_to_h_bias3) astype93: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add273, dtype="float16") add274: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype93, add271) add275: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add274, add269) layer_norm94: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add275, gpt_neox_layers_14_input_layernorm_weight3, gpt_neox_layers_14_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv234 = R.call_tir(cls.dequantize1, (gpt_neox_layers_14_attention_query_key_value_q_weight3, gpt_neox_layers_14_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims186: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv234, axes=None) matmul186: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm94, permute_dims186, out_dtype="void") add276: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul186, gpt_neox_layers_14_attention_query_key_value_bias3) reshape184: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add276, R.shape([1, seq_len, 24, 256])) reshape185: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape184, R.shape([seq_len, 24, 256])) lv235 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape185), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape186: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv235, R.shape([1, seq_len, 8, 256])) reshape187: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape186, R.shape([1, seq_len, 2048])) lv236 = R.call_tir(cls.dequantize2, (gpt_neox_layers_14_attention_dense_q_weight3, gpt_neox_layers_14_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 
2048), dtype="float16")) permute_dims187: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv236, axes=None) matmul187: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape187, permute_dims187, out_dtype="void") add277: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul187, gpt_neox_layers_14_attention_dense_bias3) layer_norm95: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add275, gpt_neox_layers_14_post_attention_layernorm_weight3, gpt_neox_layers_14_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv237 = R.call_tir(cls.dequantize3, (gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims188: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv237, axes=None) matmul188: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm95, permute_dims188, out_dtype="float32") add278: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul188, gpt_neox_layers_14_mlp_dense_h_to_4h_bias3) gelu46: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add278) astype94: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu46, dtype="float16") lv238 = R.call_tir(cls.dequantize4, (gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims189: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv238, axes=None) matmul189: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype94, permute_dims189, out_dtype="float32") add279: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul189, gpt_neox_layers_14_mlp_dense_4h_to_h_bias3) astype95: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add279, dtype="float16") add280: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype95, add277) add281: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add280, add275) layer_norm96: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add281, gpt_neox_layers_15_input_layernorm_weight3, gpt_neox_layers_15_input_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv239 = R.call_tir(cls.dequantize1, (gpt_neox_layers_15_attention_query_key_value_q_weight3, gpt_neox_layers_15_attention_query_key_value_q_scale3), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims190: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv239, axes=None) matmul190: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm96, permute_dims190, out_dtype="void") add282: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul190, gpt_neox_layers_15_attention_query_key_value_bias3) reshape188: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add282, R.shape([1, seq_len, 24, 256])) reshape189: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape188, R.shape([seq_len, 24, 256])) lv240 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape189), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape190: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv240, R.shape([1, seq_len, 8, 256])) reshape191: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape190, R.shape([1, seq_len, 2048])) lv241 = R.call_tir(cls.dequantize2, 
(gpt_neox_layers_15_attention_dense_q_weight3, gpt_neox_layers_15_attention_dense_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
permute_dims191: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv241, axes=None)
matmul191: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape191, permute_dims191, out_dtype="void")
add283: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul191, gpt_neox_layers_15_attention_dense_bias3)
layer_norm97: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add281, gpt_neox_layers_15_post_attention_layernorm_weight3, gpt_neox_layers_15_post_attention_layernorm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv242 = R.call_tir(cls.dequantize3, (gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight3, gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale3), out_sinfo=R.Tensor((8192, 2048), dtype="float16"))
permute_dims192: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv242, axes=None)
matmul192: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm97, permute_dims192, out_dtype="float32")
add284: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul192, gpt_neox_layers_15_mlp_dense_h_to_4h_bias3)
gelu47: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add284)
astype96: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu47, dtype="float16")
lv243 = R.call_tir(cls.dequantize4, (gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight3, gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale3), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
permute_dims193: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv243, axes=None)
matmul193: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype96, permute_dims193, out_dtype="float32")
add285: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul193, gpt_neox_layers_15_mlp_dense_4h_to_h_bias3)
astype97: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add285, dtype="float16")
add286: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype97, add283)
add287: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add286, add281)
layer_norm98: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add287, gpt_neox_final_layer_norm_weight3, gpt_neox_final_layer_norm_bias3, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
# logits are produced only at the positions selected by logit_positions
take1: R.Tensor((1, batch_size, 2048), dtype="float16") = R.take(layer_norm98, logit_positions, axis=1)
lv244 = R.call_tir(cls.dequantize, (embed_out_q_weight3, embed_out_q_scale3), out_sinfo=R.Tensor((vocab_size, 2048), dtype="float16"))
permute_dims194: R.Tensor((2048, vocab_size), dtype="float16") = R.permute_dims(lv244, axes=None)
matmul194: R.Tensor((1, batch_size, vocab_size), dtype="float16") = R.matmul(take1, permute_dims194, out_dtype="void")
astype98: R.Tensor((1, batch_size, vocab_size), dtype="float32") = R.astype(matmul194, dtype="float32")
gv3: R.Tuple(R.Tensor((1, batch_size, vocab_size), dtype="float32"), R.Object) = astype98, paged_kv_cache
R.output(gv3)
return gv3

# batch_verify: the same 16-layer forward pass, but returning float32 logits
# for every one of the seq_len input positions rather than only the selected
# ones, as needed when a batch of proposed tokens is verified in one pass.
@R.function
def batch_verify(input_embeds: R.Tensor((1, "seq_len", 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,),
dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), 
R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"))) -> R.Tuple(R.Tensor((1, "seq_len", "vocab_size"), dtype="float32"), R.Object): seq_len = T.int64() vocab_size = T.int64() R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size", "seq_len"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module with R.dataflow(): gpt_neox_embed_in_q_weight5: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[0] gpt_neox_embed_in_q_scale5: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[1] gpt_neox_layers_0_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[2] gpt_neox_layers_0_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[3] gpt_neox_layers_0_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[4] gpt_neox_layers_0_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[5] gpt_neox_layers_0_attention_query_key_value_q_weight5: 
R.Tensor((6144, 256), dtype="uint32") = packed_params[6] gpt_neox_layers_0_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[7] gpt_neox_layers_0_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[8] gpt_neox_layers_0_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[9] gpt_neox_layers_0_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[10] gpt_neox_layers_0_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[11] gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[12] gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[13] gpt_neox_layers_0_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[14] gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[15] gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[16] gpt_neox_layers_0_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[17] gpt_neox_layers_1_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[18] gpt_neox_layers_1_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[19] gpt_neox_layers_1_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[20] gpt_neox_layers_1_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[21] gpt_neox_layers_1_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[22] gpt_neox_layers_1_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[23] gpt_neox_layers_1_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[24] gpt_neox_layers_1_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[25] gpt_neox_layers_1_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[26] gpt_neox_layers_1_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[27] gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[28] gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[29] gpt_neox_layers_1_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[30] gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[31] gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[32] gpt_neox_layers_1_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[33] gpt_neox_layers_2_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[34] gpt_neox_layers_2_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[35] gpt_neox_layers_2_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[36] gpt_neox_layers_2_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[37] gpt_neox_layers_2_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[38] gpt_neox_layers_2_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[39] gpt_neox_layers_2_attention_query_key_value_bias5: R.Tensor((6144,), 
dtype="float16") = packed_params[40] gpt_neox_layers_2_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[41] gpt_neox_layers_2_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[42] gpt_neox_layers_2_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[43] gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[44] gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[45] gpt_neox_layers_2_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[46] gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[47] gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[48] gpt_neox_layers_2_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[49] gpt_neox_layers_3_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[50] gpt_neox_layers_3_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[51] gpt_neox_layers_3_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[52] gpt_neox_layers_3_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[53] gpt_neox_layers_3_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[54] gpt_neox_layers_3_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[55] gpt_neox_layers_3_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[56] gpt_neox_layers_3_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[57] gpt_neox_layers_3_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[58] gpt_neox_layers_3_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[59] gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[60] gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[61] gpt_neox_layers_3_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[62] gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[63] gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[64] gpt_neox_layers_3_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[65] gpt_neox_layers_4_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[66] gpt_neox_layers_4_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[67] gpt_neox_layers_4_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[68] gpt_neox_layers_4_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[69] gpt_neox_layers_4_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[70] gpt_neox_layers_4_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[71] gpt_neox_layers_4_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[72] gpt_neox_layers_4_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[73] gpt_neox_layers_4_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[74] 
gpt_neox_layers_4_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[75] gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[76] gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[77] gpt_neox_layers_4_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[78] gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[79] gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[80] gpt_neox_layers_4_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[81] gpt_neox_layers_5_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[82] gpt_neox_layers_5_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[83] gpt_neox_layers_5_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[84] gpt_neox_layers_5_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[85] gpt_neox_layers_5_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[86] gpt_neox_layers_5_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[87] gpt_neox_layers_5_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[88] gpt_neox_layers_5_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[89] gpt_neox_layers_5_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[90] gpt_neox_layers_5_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[91] gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[92] gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[93] gpt_neox_layers_5_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[94] gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[95] gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[96] gpt_neox_layers_5_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[97] gpt_neox_layers_6_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[98] gpt_neox_layers_6_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[99] gpt_neox_layers_6_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[100] gpt_neox_layers_6_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[101] gpt_neox_layers_6_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[102] gpt_neox_layers_6_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[103] gpt_neox_layers_6_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[104] gpt_neox_layers_6_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[105] gpt_neox_layers_6_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[106] gpt_neox_layers_6_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[107] gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[108] 
gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[109] gpt_neox_layers_6_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[110] gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[111] gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[112] gpt_neox_layers_6_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[113] gpt_neox_layers_7_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[114] gpt_neox_layers_7_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[115] gpt_neox_layers_7_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[116] gpt_neox_layers_7_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[117] gpt_neox_layers_7_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[118] gpt_neox_layers_7_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[119] gpt_neox_layers_7_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[120] gpt_neox_layers_7_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[121] gpt_neox_layers_7_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[122] gpt_neox_layers_7_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[123] gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[124] gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[125] gpt_neox_layers_7_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[126] gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[127] gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[128] gpt_neox_layers_7_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[129] gpt_neox_layers_8_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[130] gpt_neox_layers_8_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[131] gpt_neox_layers_8_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[132] gpt_neox_layers_8_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[133] gpt_neox_layers_8_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[134] gpt_neox_layers_8_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[135] gpt_neox_layers_8_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[136] gpt_neox_layers_8_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[137] gpt_neox_layers_8_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[138] gpt_neox_layers_8_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[139] gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[140] gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[141] gpt_neox_layers_8_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[142] 
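# note that the MLP biases (h_to_4h_bias, 4h_to_h_bias) are declared float32,
# matching the float32 out_dtype used for the MLP matmuls in this module,
# while the layernorm and attention parameters stay float16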
gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[143] gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[144] gpt_neox_layers_8_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[145] gpt_neox_layers_9_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[146] gpt_neox_layers_9_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[147] gpt_neox_layers_9_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[148] gpt_neox_layers_9_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[149] gpt_neox_layers_9_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[150] gpt_neox_layers_9_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[151] gpt_neox_layers_9_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[152] gpt_neox_layers_9_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[153] gpt_neox_layers_9_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[154] gpt_neox_layers_9_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[155] gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[156] gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[157] gpt_neox_layers_9_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[158] gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[159] gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[160] gpt_neox_layers_9_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[161] gpt_neox_layers_10_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[162] gpt_neox_layers_10_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[163] gpt_neox_layers_10_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[164] gpt_neox_layers_10_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[165] gpt_neox_layers_10_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[166] gpt_neox_layers_10_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[167] gpt_neox_layers_10_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[168] gpt_neox_layers_10_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[169] gpt_neox_layers_10_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[170] gpt_neox_layers_10_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[171] gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[172] gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[173] gpt_neox_layers_10_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[174] gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[175] gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = 
packed_params[176] gpt_neox_layers_10_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[177] gpt_neox_layers_11_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[178] gpt_neox_layers_11_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[179] gpt_neox_layers_11_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[180] gpt_neox_layers_11_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[181] gpt_neox_layers_11_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[182] gpt_neox_layers_11_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[183] gpt_neox_layers_11_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[184] gpt_neox_layers_11_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[185] gpt_neox_layers_11_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[186] gpt_neox_layers_11_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[187] gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[188] gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[189] gpt_neox_layers_11_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[190] gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[191] gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[192] gpt_neox_layers_11_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[193] gpt_neox_layers_12_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[194] gpt_neox_layers_12_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[195] gpt_neox_layers_12_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[196] gpt_neox_layers_12_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[197] gpt_neox_layers_12_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[198] gpt_neox_layers_12_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[199] gpt_neox_layers_12_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[200] gpt_neox_layers_12_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[201] gpt_neox_layers_12_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[202] gpt_neox_layers_12_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[203] gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[204] gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[205] gpt_neox_layers_12_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[206] gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[207] gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[208] gpt_neox_layers_12_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[209] gpt_neox_layers_13_input_layernorm_weight5: R.Tensor((2048,), 
dtype="float16") = packed_params[210] gpt_neox_layers_13_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[211] gpt_neox_layers_13_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[212] gpt_neox_layers_13_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[213] gpt_neox_layers_13_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[214] gpt_neox_layers_13_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[215] gpt_neox_layers_13_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[216] gpt_neox_layers_13_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[217] gpt_neox_layers_13_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[218] gpt_neox_layers_13_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[219] gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[220] gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[221] gpt_neox_layers_13_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[222] gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[223] gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[224] gpt_neox_layers_13_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[225] gpt_neox_layers_14_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[226] gpt_neox_layers_14_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[227] gpt_neox_layers_14_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[228] gpt_neox_layers_14_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[229] gpt_neox_layers_14_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[230] gpt_neox_layers_14_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[231] gpt_neox_layers_14_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[232] gpt_neox_layers_14_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[233] gpt_neox_layers_14_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[234] gpt_neox_layers_14_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[235] gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[236] gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[237] gpt_neox_layers_14_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[238] gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[239] gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[240] gpt_neox_layers_14_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[241] gpt_neox_layers_15_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[242] gpt_neox_layers_15_input_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[243] 
gpt_neox_layers_15_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[244]
gpt_neox_layers_15_post_attention_layernorm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[245]
gpt_neox_layers_15_attention_query_key_value_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[246]
gpt_neox_layers_15_attention_query_key_value_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[247]
gpt_neox_layers_15_attention_query_key_value_bias5: R.Tensor((6144,), dtype="float16") = packed_params[248]
gpt_neox_layers_15_attention_dense_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[249]
gpt_neox_layers_15_attention_dense_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[250]
gpt_neox_layers_15_attention_dense_bias5: R.Tensor((2048,), dtype="float16") = packed_params[251]
gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight5: R.Tensor((8192, 256), dtype="uint32") = packed_params[252]
gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale5: R.Tensor((8192, 64), dtype="float16") = packed_params[253]
gpt_neox_layers_15_mlp_dense_h_to_4h_bias5: R.Tensor((8192,), dtype="float32") = packed_params[254]
gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight5: R.Tensor((2048, 1024), dtype="uint32") = packed_params[255]
gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale5: R.Tensor((2048, 256), dtype="float16") = packed_params[256]
gpt_neox_layers_15_mlp_dense_4h_to_h_bias5: R.Tensor((2048,), dtype="float32") = packed_params[257]
gpt_neox_final_layer_norm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[258]
gpt_neox_final_layer_norm_bias5: R.Tensor((2048,), dtype="float16") = packed_params[259]
embed_out_q_weight5: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[260]
embed_out_q_scale5: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[261]
# the 16 decoder layers follow; each uses the GPT-NeoX parallel-residual form:
# input_layernorm and post_attention_layernorm are both applied to the same
# layer input, and the block output is mlp_out + attn_out + residual
layer_norm132: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(input_embeds, gpt_neox_layers_0_input_layernorm_weight5, gpt_neox_layers_0_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
lv326 = R.call_tir(cls.dequantize1, (gpt_neox_layers_0_attention_query_key_value_q_weight5, gpt_neox_layers_0_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
permute_dims260: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv326, axes=None)
matmul260: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm132, permute_dims260, out_dtype="void")
add384: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul260, gpt_neox_layers_0_attention_query_key_value_bias5)
reshape256: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add384, R.shape([1, seq_len, 24, 256]))
reshape257: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape256, R.shape([seq_len, 24, 256]))
lv327 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape257), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
reshape258: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv327, R.shape([1, seq_len, 8, 256]))
reshape259: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape258, R.shape([1, seq_len, 2048]))
lv328 = R.call_tir(cls.dequantize2, (gpt_neox_layers_0_attention_dense_q_weight5, gpt_neox_layers_0_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
permute_dims261: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv328, axes=None)
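# (editor's note) the fused QKV tensor (seq_len, 24, 256) packs 8 query, 8 key
# and 8 value heads of dimension 256; the packed call above hands it to the
# paged KV cache together with the layer index and what appears to be an
# attention-score scale of 1.0, and receives the (seq_len, 8, 256) attention
# output that is reshaped back to (1, seq_len, 2048)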
matmul261: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape259, permute_dims261, out_dtype="void") add385: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul261, gpt_neox_layers_0_attention_dense_bias5) layer_norm133: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(input_embeds, gpt_neox_layers_0_post_attention_layernorm_weight5, gpt_neox_layers_0_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv329 = R.call_tir(cls.dequantize3, (gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims262: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv329, axes=None) matmul262: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm133, permute_dims262, out_dtype="float32") add386: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul262, gpt_neox_layers_0_mlp_dense_h_to_4h_bias5) gelu64: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add386) astype132: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu64, dtype="float16") lv330 = R.call_tir(cls.dequantize4, (gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims263: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv330, axes=None) matmul263: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype132, permute_dims263, out_dtype="float32") add387: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul263, gpt_neox_layers_0_mlp_dense_4h_to_h_bias5) astype133: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add387, dtype="float16") add388: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype133, add385) add389: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add388, input_embeds) layer_norm134: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add389, gpt_neox_layers_1_input_layernorm_weight5, gpt_neox_layers_1_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv331 = R.call_tir(cls.dequantize1, (gpt_neox_layers_1_attention_query_key_value_q_weight5, gpt_neox_layers_1_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims264: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv331, axes=None) matmul264: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm134, permute_dims264, out_dtype="void") add390: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul264, gpt_neox_layers_1_attention_query_key_value_bias5) reshape260: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add390, R.shape([1, seq_len, 24, 256])) reshape261: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape260, R.shape([seq_len, 24, 256])) lv332 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape261), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape262: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv332, R.shape([1, seq_len, 8, 256])) reshape263: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape262, R.shape([1, seq_len, 2048])) lv333 = R.call_tir(cls.dequantize2, (gpt_neox_layers_1_attention_dense_q_weight5, gpt_neox_layers_1_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) 
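# (editor's note) a minimal NumPy-flavoured sketch of what the dequantize
# kernels appear to compute, inferred only from the buffer shapes above;
# the bit order and the zero point of 7 are assumptions, not taken from
# this module:
#
#   def dequantize_sketch(q_weight, q_scale, n, k):
#       # q_weight: (n, k // 8) uint32, q_scale: (n, k // 32) float16
#       nibbles = np.stack([(q_weight >> (4 * j)) & 0xF for j in range(8)], axis=-1)
#       w = nibbles.reshape(n, k).astype("float16")
#       return (w - 7.0) * np.repeat(q_scale, 32, axis=1)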
permute_dims265: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv333, axes=None) matmul265: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape263, permute_dims265, out_dtype="void") add391: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul265, gpt_neox_layers_1_attention_dense_bias5) layer_norm135: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add389, gpt_neox_layers_1_post_attention_layernorm_weight5, gpt_neox_layers_1_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv334 = R.call_tir(cls.dequantize3, (gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims266: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv334, axes=None) matmul266: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm135, permute_dims266, out_dtype="float32") add392: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul266, gpt_neox_layers_1_mlp_dense_h_to_4h_bias5) gelu65: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add392) astype134: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu65, dtype="float16") lv335 = R.call_tir(cls.dequantize4, (gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims267: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv335, axes=None) matmul267: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype134, permute_dims267, out_dtype="float32") add393: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul267, gpt_neox_layers_1_mlp_dense_4h_to_h_bias5) astype135: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add393, dtype="float16") add394: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype135, add391) add395: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add394, add389) layer_norm136: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add395, gpt_neox_layers_2_input_layernorm_weight5, gpt_neox_layers_2_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv336 = R.call_tir(cls.dequantize1, (gpt_neox_layers_2_attention_query_key_value_q_weight5, gpt_neox_layers_2_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims268: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv336, axes=None) matmul268: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm136, permute_dims268, out_dtype="void") add396: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul268, gpt_neox_layers_2_attention_query_key_value_bias5) reshape264: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add396, R.shape([1, seq_len, 24, 256])) reshape265: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape264, R.shape([seq_len, 24, 256])) lv337 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape265), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape266: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv337, R.shape([1, seq_len, 8, 256])) reshape267: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape266, R.shape([1, seq_len, 2048])) lv338 = R.call_tir(cls.dequantize2, (gpt_neox_layers_2_attention_dense_q_weight5, 
gpt_neox_layers_2_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims269: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv338, axes=None) matmul269: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape267, permute_dims269, out_dtype="void") add397: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul269, gpt_neox_layers_2_attention_dense_bias5) layer_norm137: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add395, gpt_neox_layers_2_post_attention_layernorm_weight5, gpt_neox_layers_2_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv339 = R.call_tir(cls.dequantize3, (gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims270: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv339, axes=None) matmul270: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm137, permute_dims270, out_dtype="float32") add398: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul270, gpt_neox_layers_2_mlp_dense_h_to_4h_bias5) gelu66: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add398) astype136: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu66, dtype="float16") lv340 = R.call_tir(cls.dequantize4, (gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims271: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv340, axes=None) matmul271: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype136, permute_dims271, out_dtype="float32") add399: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul271, gpt_neox_layers_2_mlp_dense_4h_to_h_bias5) astype137: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add399, dtype="float16") add400: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype137, add397) add401: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add400, add395) layer_norm138: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add401, gpt_neox_layers_3_input_layernorm_weight5, gpt_neox_layers_3_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv341 = R.call_tir(cls.dequantize1, (gpt_neox_layers_3_attention_query_key_value_q_weight5, gpt_neox_layers_3_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims272: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv341, axes=None) matmul272: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm138, permute_dims272, out_dtype="void") add402: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul272, gpt_neox_layers_3_attention_query_key_value_bias5) reshape268: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add402, R.shape([1, seq_len, 24, 256])) reshape269: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape268, R.shape([seq_len, 24, 256])) lv342 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape269), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape270: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv342, R.shape([1, seq_len, 8, 256])) reshape271: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape270, R.shape([1, seq_len, 2048])) lv343 = 
R.call_tir(cls.dequantize2, (gpt_neox_layers_3_attention_dense_q_weight5, gpt_neox_layers_3_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims273: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv343, axes=None) matmul273: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape271, permute_dims273, out_dtype="void") add403: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul273, gpt_neox_layers_3_attention_dense_bias5) layer_norm139: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add401, gpt_neox_layers_3_post_attention_layernorm_weight5, gpt_neox_layers_3_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv344 = R.call_tir(cls.dequantize3, (gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims274: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv344, axes=None) matmul274: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm139, permute_dims274, out_dtype="float32") add404: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul274, gpt_neox_layers_3_mlp_dense_h_to_4h_bias5) gelu67: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add404) astype138: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu67, dtype="float16") lv345 = R.call_tir(cls.dequantize4, (gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims275: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv345, axes=None) matmul275: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype138, permute_dims275, out_dtype="float32") add405: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul275, gpt_neox_layers_3_mlp_dense_4h_to_h_bias5) astype139: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add405, dtype="float16") add406: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype139, add403) add407: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add406, add401) layer_norm140: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add407, gpt_neox_layers_4_input_layernorm_weight5, gpt_neox_layers_4_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv346 = R.call_tir(cls.dequantize1, (gpt_neox_layers_4_attention_query_key_value_q_weight5, gpt_neox_layers_4_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims276: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv346, axes=None) matmul276: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm140, permute_dims276, out_dtype="void") add408: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul276, gpt_neox_layers_4_attention_query_key_value_bias5) reshape272: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add408, R.shape([1, seq_len, 24, 256])) reshape273: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape272, R.shape([seq_len, 24, 256])) lv347 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape273), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape274: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv347, R.shape([1, seq_len, 8, 256])) reshape275: R.Tensor((1, seq_len, 2048), 
dtype="float16") = R.reshape(reshape274, R.shape([1, seq_len, 2048])) lv348 = R.call_tir(cls.dequantize2, (gpt_neox_layers_4_attention_dense_q_weight5, gpt_neox_layers_4_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims277: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv348, axes=None) matmul277: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape275, permute_dims277, out_dtype="void") add409: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul277, gpt_neox_layers_4_attention_dense_bias5) layer_norm141: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add407, gpt_neox_layers_4_post_attention_layernorm_weight5, gpt_neox_layers_4_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv349 = R.call_tir(cls.dequantize3, (gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims278: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv349, axes=None) matmul278: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm141, permute_dims278, out_dtype="float32") add410: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul278, gpt_neox_layers_4_mlp_dense_h_to_4h_bias5) gelu68: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add410) astype140: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu68, dtype="float16") lv350 = R.call_tir(cls.dequantize4, (gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims279: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv350, axes=None) matmul279: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype140, permute_dims279, out_dtype="float32") add411: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul279, gpt_neox_layers_4_mlp_dense_4h_to_h_bias5) astype141: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add411, dtype="float16") add412: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype141, add409) add413: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add412, add407) layer_norm142: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add413, gpt_neox_layers_5_input_layernorm_weight5, gpt_neox_layers_5_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv351 = R.call_tir(cls.dequantize1, (gpt_neox_layers_5_attention_query_key_value_q_weight5, gpt_neox_layers_5_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims280: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv351, axes=None) matmul280: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm142, permute_dims280, out_dtype="void") add414: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul280, gpt_neox_layers_5_attention_query_key_value_bias5) reshape276: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add414, R.shape([1, seq_len, 24, 256])) reshape277: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape276, R.shape([seq_len, 24, 256])) lv352 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape277), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape278: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv352, 
R.shape([1, seq_len, 8, 256])) reshape279: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape278, R.shape([1, seq_len, 2048])) lv353 = R.call_tir(cls.dequantize2, (gpt_neox_layers_5_attention_dense_q_weight5, gpt_neox_layers_5_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims281: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv353, axes=None) matmul281: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape279, permute_dims281, out_dtype="void") add415: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul281, gpt_neox_layers_5_attention_dense_bias5) layer_norm143: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add413, gpt_neox_layers_5_post_attention_layernorm_weight5, gpt_neox_layers_5_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv354 = R.call_tir(cls.dequantize3, (gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims282: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv354, axes=None) matmul282: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm143, permute_dims282, out_dtype="float32") add416: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul282, gpt_neox_layers_5_mlp_dense_h_to_4h_bias5) gelu69: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add416) astype142: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu69, dtype="float16") lv355 = R.call_tir(cls.dequantize4, (gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims283: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv355, axes=None) matmul283: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype142, permute_dims283, out_dtype="float32") add417: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul283, gpt_neox_layers_5_mlp_dense_4h_to_h_bias5) astype143: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add417, dtype="float16") add418: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype143, add415) add419: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add418, add413) layer_norm144: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add419, gpt_neox_layers_6_input_layernorm_weight5, gpt_neox_layers_6_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv356 = R.call_tir(cls.dequantize1, (gpt_neox_layers_6_attention_query_key_value_q_weight5, gpt_neox_layers_6_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims284: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv356, axes=None) matmul284: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm144, permute_dims284, out_dtype="void") add420: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul284, gpt_neox_layers_6_attention_query_key_value_bias5) reshape280: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add420, R.shape([1, seq_len, 24, 256])) reshape281: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape280, R.shape([seq_len, 24, 256])) lv357 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape281), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) 
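# NOTE: the attention of every layer is delegated to the paged KV cache runtime:
# R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", ...)
# receives the cache object, the layer index as a prim_value (6 in the call just
# above), an attention-score scaling factor, and the fused (seq_len, 24, 256) QKV
# tensor, and returns the (seq_len, 8, 256) attention output after appending K/V
# to the cache. With the 8 attention heads and 8 KV heads configured in
# create_tir_paged_kv_cache further below, the 24 fused heads presumably split
# head-wise -- an illustrative slicing, not a line of this module:
#
#     q, k, v = qkv[:, 0:8, :], qkv[:, 8:16, :], qkv[:, 16:24, :]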
reshape282: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv357, R.shape([1, seq_len, 8, 256])) reshape283: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape282, R.shape([1, seq_len, 2048])) lv358 = R.call_tir(cls.dequantize2, (gpt_neox_layers_6_attention_dense_q_weight5, gpt_neox_layers_6_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims285: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv358, axes=None) matmul285: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape283, permute_dims285, out_dtype="void") add421: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul285, gpt_neox_layers_6_attention_dense_bias5) layer_norm145: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add419, gpt_neox_layers_6_post_attention_layernorm_weight5, gpt_neox_layers_6_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv359 = R.call_tir(cls.dequantize3, (gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims286: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv359, axes=None) matmul286: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm145, permute_dims286, out_dtype="float32") add422: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul286, gpt_neox_layers_6_mlp_dense_h_to_4h_bias5) gelu70: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add422) astype144: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu70, dtype="float16") lv360 = R.call_tir(cls.dequantize4, (gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims287: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv360, axes=None) matmul287: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype144, permute_dims287, out_dtype="float32") add423: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul287, gpt_neox_layers_6_mlp_dense_4h_to_h_bias5) astype145: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add423, dtype="float16") add424: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype145, add421) add425: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add424, add419) layer_norm146: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add425, gpt_neox_layers_7_input_layernorm_weight5, gpt_neox_layers_7_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv361 = R.call_tir(cls.dequantize1, (gpt_neox_layers_7_attention_query_key_value_q_weight5, gpt_neox_layers_7_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims288: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv361, axes=None) matmul288: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm146, permute_dims288, out_dtype="void") add426: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul288, gpt_neox_layers_7_attention_query_key_value_bias5) reshape284: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add426, R.shape([1, seq_len, 24, 256])) reshape285: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape284, R.shape([seq_len, 24, 256])) lv362 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), 
R.prim_value(T.float32(1.0)), reshape285), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape286: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv362, R.shape([1, seq_len, 8, 256])) reshape287: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape286, R.shape([1, seq_len, 2048])) lv363 = R.call_tir(cls.dequantize2, (gpt_neox_layers_7_attention_dense_q_weight5, gpt_neox_layers_7_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims289: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv363, axes=None) matmul289: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape287, permute_dims289, out_dtype="void") add427: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul289, gpt_neox_layers_7_attention_dense_bias5) layer_norm147: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add425, gpt_neox_layers_7_post_attention_layernorm_weight5, gpt_neox_layers_7_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv364 = R.call_tir(cls.dequantize3, (gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims290: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv364, axes=None) matmul290: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm147, permute_dims290, out_dtype="float32") add428: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul290, gpt_neox_layers_7_mlp_dense_h_to_4h_bias5) gelu71: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add428) astype146: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu71, dtype="float16") lv365 = R.call_tir(cls.dequantize4, (gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims291: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv365, axes=None) matmul291: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype146, permute_dims291, out_dtype="float32") add429: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul291, gpt_neox_layers_7_mlp_dense_4h_to_h_bias5) astype147: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add429, dtype="float16") add430: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype147, add427) add431: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add430, add425) layer_norm148: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add431, gpt_neox_layers_8_input_layernorm_weight5, gpt_neox_layers_8_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv366 = R.call_tir(cls.dequantize1, (gpt_neox_layers_8_attention_query_key_value_q_weight5, gpt_neox_layers_8_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims292: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv366, axes=None) matmul292: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm148, permute_dims292, out_dtype="void") add432: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul292, gpt_neox_layers_8_attention_query_key_value_bias5) reshape288: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add432, R.shape([1, seq_len, 24, 256])) reshape289: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape288, R.shape([seq_len, 24, 256])) lv367 = 
R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape289), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape290: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv367, R.shape([1, seq_len, 8, 256])) reshape291: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape290, R.shape([1, seq_len, 2048])) lv368 = R.call_tir(cls.dequantize2, (gpt_neox_layers_8_attention_dense_q_weight5, gpt_neox_layers_8_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims293: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv368, axes=None) matmul293: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape291, permute_dims293, out_dtype="void") add433: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul293, gpt_neox_layers_8_attention_dense_bias5) layer_norm149: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add431, gpt_neox_layers_8_post_attention_layernorm_weight5, gpt_neox_layers_8_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv369 = R.call_tir(cls.dequantize3, (gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims294: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv369, axes=None) matmul294: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm149, permute_dims294, out_dtype="float32") add434: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul294, gpt_neox_layers_8_mlp_dense_h_to_4h_bias5) gelu72: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add434) astype148: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu72, dtype="float16") lv370 = R.call_tir(cls.dequantize4, (gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims295: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv370, axes=None) matmul295: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype148, permute_dims295, out_dtype="float32") add435: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul295, gpt_neox_layers_8_mlp_dense_4h_to_h_bias5) astype149: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add435, dtype="float16") add436: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype149, add433) add437: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add436, add431) layer_norm150: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add437, gpt_neox_layers_9_input_layernorm_weight5, gpt_neox_layers_9_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv371 = R.call_tir(cls.dequantize1, (gpt_neox_layers_9_attention_query_key_value_q_weight5, gpt_neox_layers_9_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims296: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv371, axes=None) matmul296: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm150, permute_dims296, out_dtype="void") add438: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul296, gpt_neox_layers_9_attention_query_key_value_bias5) reshape292: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add438, R.shape([1, seq_len, 24, 256])) reshape293: R.Tensor((seq_len, 24, 256), 
dtype="float16") = R.reshape(reshape292, R.shape([seq_len, 24, 256])) lv372 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape293), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape294: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv372, R.shape([1, seq_len, 8, 256])) reshape295: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape294, R.shape([1, seq_len, 2048])) lv373 = R.call_tir(cls.dequantize2, (gpt_neox_layers_9_attention_dense_q_weight5, gpt_neox_layers_9_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims297: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv373, axes=None) matmul297: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape295, permute_dims297, out_dtype="void") add439: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul297, gpt_neox_layers_9_attention_dense_bias5) layer_norm151: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add437, gpt_neox_layers_9_post_attention_layernorm_weight5, gpt_neox_layers_9_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv374 = R.call_tir(cls.dequantize3, (gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims298: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv374, axes=None) matmul298: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm151, permute_dims298, out_dtype="float32") add440: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul298, gpt_neox_layers_9_mlp_dense_h_to_4h_bias5) gelu73: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add440) astype150: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu73, dtype="float16") lv375 = R.call_tir(cls.dequantize4, (gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims299: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv375, axes=None) matmul299: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype150, permute_dims299, out_dtype="float32") add441: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul299, gpt_neox_layers_9_mlp_dense_4h_to_h_bias5) astype151: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add441, dtype="float16") add442: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype151, add439) add443: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add442, add437) layer_norm152: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add443, gpt_neox_layers_10_input_layernorm_weight5, gpt_neox_layers_10_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv376 = R.call_tir(cls.dequantize1, (gpt_neox_layers_10_attention_query_key_value_q_weight5, gpt_neox_layers_10_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims300: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv376, axes=None) matmul300: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm152, permute_dims300, out_dtype="void") add444: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul300, gpt_neox_layers_10_attention_query_key_value_bias5) reshape296: R.Tensor((1, seq_len, 24, 256), dtype="float16") = 
R.reshape(add444, R.shape([1, seq_len, 24, 256])) reshape297: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape296, R.shape([seq_len, 24, 256])) lv377 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape297), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape298: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv377, R.shape([1, seq_len, 8, 256])) reshape299: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape298, R.shape([1, seq_len, 2048])) lv378 = R.call_tir(cls.dequantize2, (gpt_neox_layers_10_attention_dense_q_weight5, gpt_neox_layers_10_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims301: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv378, axes=None) matmul301: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape299, permute_dims301, out_dtype="void") add445: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul301, gpt_neox_layers_10_attention_dense_bias5) layer_norm153: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add443, gpt_neox_layers_10_post_attention_layernorm_weight5, gpt_neox_layers_10_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv379 = R.call_tir(cls.dequantize3, (gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims302: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv379, axes=None) matmul302: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm153, permute_dims302, out_dtype="float32") add446: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul302, gpt_neox_layers_10_mlp_dense_h_to_4h_bias5) gelu74: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add446) astype152: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu74, dtype="float16") lv380 = R.call_tir(cls.dequantize4, (gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims303: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv380, axes=None) matmul303: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype152, permute_dims303, out_dtype="float32") add447: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul303, gpt_neox_layers_10_mlp_dense_4h_to_h_bias5) astype153: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add447, dtype="float16") add448: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype153, add445) add449: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add448, add443) layer_norm154: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add449, gpt_neox_layers_11_input_layernorm_weight5, gpt_neox_layers_11_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv381 = R.call_tir(cls.dequantize1, (gpt_neox_layers_11_attention_query_key_value_q_weight5, gpt_neox_layers_11_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims304: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv381, axes=None) matmul304: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm154, permute_dims304, out_dtype="void") add450: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul304, 
gpt_neox_layers_11_attention_query_key_value_bias5) reshape300: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add450, R.shape([1, seq_len, 24, 256])) reshape301: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape300, R.shape([seq_len, 24, 256])) lv382 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape301), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape302: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv382, R.shape([1, seq_len, 8, 256])) reshape303: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape302, R.shape([1, seq_len, 2048])) lv383 = R.call_tir(cls.dequantize2, (gpt_neox_layers_11_attention_dense_q_weight5, gpt_neox_layers_11_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims305: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv383, axes=None) matmul305: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape303, permute_dims305, out_dtype="void") add451: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul305, gpt_neox_layers_11_attention_dense_bias5) layer_norm155: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add449, gpt_neox_layers_11_post_attention_layernorm_weight5, gpt_neox_layers_11_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv384 = R.call_tir(cls.dequantize3, (gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims306: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv384, axes=None) matmul306: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm155, permute_dims306, out_dtype="float32") add452: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul306, gpt_neox_layers_11_mlp_dense_h_to_4h_bias5) gelu75: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add452) astype154: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu75, dtype="float16") lv385 = R.call_tir(cls.dequantize4, (gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims307: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv385, axes=None) matmul307: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype154, permute_dims307, out_dtype="float32") add453: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul307, gpt_neox_layers_11_mlp_dense_4h_to_h_bias5) astype155: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add453, dtype="float16") add454: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype155, add451) add455: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add454, add449) layer_norm156: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add455, gpt_neox_layers_12_input_layernorm_weight5, gpt_neox_layers_12_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv386 = R.call_tir(cls.dequantize1, (gpt_neox_layers_12_attention_query_key_value_q_weight5, gpt_neox_layers_12_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims308: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv386, axes=None) matmul308: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm156, permute_dims308, 
out_dtype="void") add456: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul308, gpt_neox_layers_12_attention_query_key_value_bias5) reshape304: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add456, R.shape([1, seq_len, 24, 256])) reshape305: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape304, R.shape([seq_len, 24, 256])) lv387 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape305), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape306: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv387, R.shape([1, seq_len, 8, 256])) reshape307: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape306, R.shape([1, seq_len, 2048])) lv388 = R.call_tir(cls.dequantize2, (gpt_neox_layers_12_attention_dense_q_weight5, gpt_neox_layers_12_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims309: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv388, axes=None) matmul309: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape307, permute_dims309, out_dtype="void") add457: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul309, gpt_neox_layers_12_attention_dense_bias5) layer_norm157: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add455, gpt_neox_layers_12_post_attention_layernorm_weight5, gpt_neox_layers_12_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv389 = R.call_tir(cls.dequantize3, (gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims310: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv389, axes=None) matmul310: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm157, permute_dims310, out_dtype="float32") add458: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul310, gpt_neox_layers_12_mlp_dense_h_to_4h_bias5) gelu76: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add458) astype156: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu76, dtype="float16") lv390 = R.call_tir(cls.dequantize4, (gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims311: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv390, axes=None) matmul311: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype156, permute_dims311, out_dtype="float32") add459: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul311, gpt_neox_layers_12_mlp_dense_4h_to_h_bias5) astype157: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add459, dtype="float16") add460: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype157, add457) add461: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add460, add455) layer_norm158: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add461, gpt_neox_layers_13_input_layernorm_weight5, gpt_neox_layers_13_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv391 = R.call_tir(cls.dequantize1, (gpt_neox_layers_13_attention_query_key_value_q_weight5, gpt_neox_layers_13_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims312: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv391, axes=None) matmul312: 
R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm158, permute_dims312, out_dtype="void") add462: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul312, gpt_neox_layers_13_attention_query_key_value_bias5) reshape308: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add462, R.shape([1, seq_len, 24, 256])) reshape309: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape308, R.shape([seq_len, 24, 256])) lv392 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape309), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape310: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv392, R.shape([1, seq_len, 8, 256])) reshape311: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape310, R.shape([1, seq_len, 2048])) lv393 = R.call_tir(cls.dequantize2, (gpt_neox_layers_13_attention_dense_q_weight5, gpt_neox_layers_13_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims313: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv393, axes=None) matmul313: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape311, permute_dims313, out_dtype="void") add463: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul313, gpt_neox_layers_13_attention_dense_bias5) layer_norm159: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add461, gpt_neox_layers_13_post_attention_layernorm_weight5, gpt_neox_layers_13_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv394 = R.call_tir(cls.dequantize3, (gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims314: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv394, axes=None) matmul314: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm159, permute_dims314, out_dtype="float32") add464: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul314, gpt_neox_layers_13_mlp_dense_h_to_4h_bias5) gelu77: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add464) astype158: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu77, dtype="float16") lv395 = R.call_tir(cls.dequantize4, (gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims315: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv395, axes=None) matmul315: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype158, permute_dims315, out_dtype="float32") add465: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul315, gpt_neox_layers_13_mlp_dense_4h_to_h_bias5) astype159: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add465, dtype="float16") add466: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype159, add463) add467: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add466, add461) layer_norm160: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add467, gpt_neox_layers_14_input_layernorm_weight5, gpt_neox_layers_14_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv396 = R.call_tir(cls.dequantize1, (gpt_neox_layers_14_attention_query_key_value_q_weight5, gpt_neox_layers_14_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) 
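# NOTE: every MLP runs in mixed precision -- both matmuls request
# out_dtype="float32", the bias add and GELU stay in float32, and R.astype only
# narrows back to float16 afterwards, keeping the nonlinearity out of half
# precision. Schematically, with the same relax ops (a sketch, not a line of this
# module):
#
#     h = R.matmul(x_fp16, w_fp16, out_dtype="float32")
#     y = R.astype(R.nn.gelu(R.add(h, bias_fp32)), dtype="float16")
#
# The two adds that close each layer (e.g. add466 and add467 above) are GPT-NeoX's
# parallel residual: mlp_output + attention_output + layer_input.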
permute_dims316: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv396, axes=None) matmul316: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm160, permute_dims316, out_dtype="void") add468: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul316, gpt_neox_layers_14_attention_query_key_value_bias5) reshape312: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add468, R.shape([1, seq_len, 24, 256])) reshape313: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape312, R.shape([seq_len, 24, 256])) lv397 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape313), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape314: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv397, R.shape([1, seq_len, 8, 256])) reshape315: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape314, R.shape([1, seq_len, 2048])) lv398 = R.call_tir(cls.dequantize2, (gpt_neox_layers_14_attention_dense_q_weight5, gpt_neox_layers_14_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims317: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv398, axes=None) matmul317: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape315, permute_dims317, out_dtype="void") add469: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul317, gpt_neox_layers_14_attention_dense_bias5) layer_norm161: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add467, gpt_neox_layers_14_post_attention_layernorm_weight5, gpt_neox_layers_14_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv399 = R.call_tir(cls.dequantize3, (gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims318: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv399, axes=None) matmul318: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm161, permute_dims318, out_dtype="float32") add470: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul318, gpt_neox_layers_14_mlp_dense_h_to_4h_bias5) gelu78: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add470) astype160: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu78, dtype="float16") lv400 = R.call_tir(cls.dequantize4, (gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims319: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv400, axes=None) matmul319: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype160, permute_dims319, out_dtype="float32") add471: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul319, gpt_neox_layers_14_mlp_dense_4h_to_h_bias5) astype161: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add471, dtype="float16") add472: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype161, add469) add473: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add472, add467) layer_norm162: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add473, gpt_neox_layers_15_input_layernorm_weight5, gpt_neox_layers_15_input_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv401 = R.call_tir(cls.dequantize1, (gpt_neox_layers_15_attention_query_key_value_q_weight5, 
gpt_neox_layers_15_attention_query_key_value_q_scale5), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims320: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv401, axes=None) matmul320: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm162, permute_dims320, out_dtype="void") add474: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul320, gpt_neox_layers_15_attention_query_key_value_bias5) reshape316: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add474, R.shape([1, seq_len, 24, 256])) reshape317: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape316, R.shape([seq_len, 24, 256])) lv402 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape317), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape318: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv402, R.shape([1, seq_len, 8, 256])) reshape319: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape318, R.shape([1, seq_len, 2048])) lv403 = R.call_tir(cls.dequantize2, (gpt_neox_layers_15_attention_dense_q_weight5, gpt_neox_layers_15_attention_dense_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims321: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv403, axes=None) matmul321: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape319, permute_dims321, out_dtype="void") add475: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul321, gpt_neox_layers_15_attention_dense_bias5) layer_norm163: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add473, gpt_neox_layers_15_post_attention_layernorm_weight5, gpt_neox_layers_15_post_attention_layernorm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv404 = R.call_tir(cls.dequantize3, (gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight5, gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale5), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims322: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv404, axes=None) matmul322: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm163, permute_dims322, out_dtype="float32") add476: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul322, gpt_neox_layers_15_mlp_dense_h_to_4h_bias5) gelu79: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add476) astype162: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu79, dtype="float16") lv405 = R.call_tir(cls.dequantize4, (gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight5, gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale5), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims323: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv405, axes=None) matmul323: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype162, permute_dims323, out_dtype="float32") add477: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul323, gpt_neox_layers_15_mlp_dense_4h_to_h_bias5) astype163: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add477, dtype="float16") add478: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype163, add475) add479: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add478, add473) layer_norm164: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add479, gpt_neox_final_layer_norm_weight5, gpt_neox_final_layer_norm_bias5, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv406 = 
R.call_tir(cls.dequantize, (embed_out_q_weight5, embed_out_q_scale5), out_sinfo=R.Tensor((vocab_size, 2048), dtype="float16")) permute_dims324: R.Tensor((2048, vocab_size), dtype="float16") = R.permute_dims(lv406, axes=None) matmul324: R.Tensor((1, seq_len, vocab_size), dtype="float16") = R.matmul(layer_norm164, permute_dims324, out_dtype="void") astype164: R.Tensor((1, seq_len, vocab_size), dtype="float32") = R.astype(matmul324, dtype="float32") gv5: R.Tuple(R.Tensor((1, seq_len, vocab_size), dtype="float32"), R.Object) = astype164, paged_kv_cache R.output(gv5) return gv5 @R.function def create_tir_paged_kv_cache(max_batch_size_: R.Shape(["max_batch_size"]), max_total_seq_len_: R.Shape(["max_total_seq_len"]), prefill_chunk_size_: R.Shape(["prefill_chunk_size"]), page_size_: R.Shape(["page_size"]), support_sliding_window_: R.Shape(["support_sliding_window"])) -> R.Object: max_batch_size = T.int64() max_total_seq_len = T.int64() prefill_chunk_size = T.int64() page_size = T.int64() support_sliding_window = T.int64() R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module gv: R.Tensor((), dtype="float16") = R.zeros(R.shape([]), dtype="float16") paged_kv_cache: R.Object = R.call_pure_packed("vm.builtin.paged_attention_kv_cache_create_reduced", R.shape([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size, support_sliding_window]), R.shape([0, 16]), R.prim_value(8), R.prim_value(8), R.prim_value(256), R.prim_value(1), R.prim_value(1), R.prim_value(10000), gv, cls.tir_kv_cache_transpose_append, cls.batch_prefill_paged_kv, cls.batch_decode_paged_kv, cls.batch_prefill_paged_kv_sliding_window, cls.batch_decode_paged_kv_sliding_window, cls.batch_prefill_ragged_kv, cls.merge_state_inplace, cls.fused_rope, cls.copy_single_page, cls.tir_kv_cache_debug_get_kv, cls.compact_kv_copy, cls.batch_tree_attn, cls.tree_attn_paged_kv, R.prim_value(0), R.prim_value(0), sinfo_args=(R.Object,)) return paged_kv_cache @R.function def decode(input_embed: R.Tensor((1, 1, 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), 
dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), 
R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, "vocab_size"), dtype="float32"), R.Object): vocab_size = T.int64() R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module with R.dataflow(): gpt_neox_embed_in_q_weight2: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[0] gpt_neox_embed_in_q_scale2: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[1] gpt_neox_layers_0_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[2] gpt_neox_layers_0_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[3] gpt_neox_layers_0_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[4] gpt_neox_layers_0_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[5] gpt_neox_layers_0_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[6] gpt_neox_layers_0_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[7] gpt_neox_layers_0_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[8] gpt_neox_layers_0_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[9] gpt_neox_layers_0_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[10] gpt_neox_layers_0_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[11] gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[12] gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[13] gpt_neox_layers_0_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[14] gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[15] gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[16] 
gpt_neox_layers_0_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[17] gpt_neox_layers_1_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[18] gpt_neox_layers_1_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[19] gpt_neox_layers_1_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[20] gpt_neox_layers_1_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[21] gpt_neox_layers_1_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[22] gpt_neox_layers_1_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[23] gpt_neox_layers_1_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[24] gpt_neox_layers_1_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[25] gpt_neox_layers_1_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[26] gpt_neox_layers_1_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[27] gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[28] gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[29] gpt_neox_layers_1_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[30] gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[31] gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[32] gpt_neox_layers_1_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[33] gpt_neox_layers_2_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[34] gpt_neox_layers_2_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[35] gpt_neox_layers_2_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[36] gpt_neox_layers_2_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[37] gpt_neox_layers_2_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[38] gpt_neox_layers_2_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[39] gpt_neox_layers_2_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[40] gpt_neox_layers_2_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[41] gpt_neox_layers_2_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[42] gpt_neox_layers_2_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[43] gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[44] gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[45] gpt_neox_layers_2_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[46] gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[47] gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[48] gpt_neox_layers_2_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[49] gpt_neox_layers_3_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[50] gpt_neox_layers_3_input_layernorm_bias2: 
R.Tensor((2048,), dtype="float16") = packed_params[51] gpt_neox_layers_3_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[52] gpt_neox_layers_3_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[53] gpt_neox_layers_3_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[54] gpt_neox_layers_3_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[55] gpt_neox_layers_3_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[56] gpt_neox_layers_3_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[57] gpt_neox_layers_3_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[58] gpt_neox_layers_3_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[59] gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[60] gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[61] gpt_neox_layers_3_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[62] gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[63] gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[64] gpt_neox_layers_3_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[65] gpt_neox_layers_4_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[66] gpt_neox_layers_4_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[67] gpt_neox_layers_4_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[68] gpt_neox_layers_4_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[69] gpt_neox_layers_4_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[70] gpt_neox_layers_4_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[71] gpt_neox_layers_4_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[72] gpt_neox_layers_4_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[73] gpt_neox_layers_4_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[74] gpt_neox_layers_4_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[75] gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[76] gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[77] gpt_neox_layers_4_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[78] gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[79] gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[80] gpt_neox_layers_4_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[81] gpt_neox_layers_5_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[82] gpt_neox_layers_5_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[83] gpt_neox_layers_5_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[84] gpt_neox_layers_5_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") 
= packed_params[85] gpt_neox_layers_5_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[86] gpt_neox_layers_5_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[87] gpt_neox_layers_5_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[88] gpt_neox_layers_5_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[89] gpt_neox_layers_5_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[90] gpt_neox_layers_5_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[91] gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[92] gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[93] gpt_neox_layers_5_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[94] gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[95] gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[96] gpt_neox_layers_5_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[97] gpt_neox_layers_6_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[98] gpt_neox_layers_6_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[99] gpt_neox_layers_6_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[100] gpt_neox_layers_6_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[101] gpt_neox_layers_6_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[102] gpt_neox_layers_6_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[103] gpt_neox_layers_6_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[104] gpt_neox_layers_6_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[105] gpt_neox_layers_6_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[106] gpt_neox_layers_6_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[107] gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[108] gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[109] gpt_neox_layers_6_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[110] gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[111] gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[112] gpt_neox_layers_6_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[113] gpt_neox_layers_7_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[114] gpt_neox_layers_7_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[115] gpt_neox_layers_7_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[116] gpt_neox_layers_7_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[117] gpt_neox_layers_7_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[118] gpt_neox_layers_7_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = 
packed_params[119] gpt_neox_layers_7_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[120] gpt_neox_layers_7_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[121] gpt_neox_layers_7_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[122] gpt_neox_layers_7_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[123] gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[124] gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[125] gpt_neox_layers_7_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[126] gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[127] gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[128] gpt_neox_layers_7_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[129] gpt_neox_layers_8_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[130] gpt_neox_layers_8_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[131] gpt_neox_layers_8_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[132] gpt_neox_layers_8_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[133] gpt_neox_layers_8_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[134] gpt_neox_layers_8_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[135] gpt_neox_layers_8_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[136] gpt_neox_layers_8_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[137] gpt_neox_layers_8_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[138] gpt_neox_layers_8_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[139] gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[140] gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[141] gpt_neox_layers_8_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[142] gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[143] gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[144] gpt_neox_layers_8_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[145] gpt_neox_layers_9_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[146] gpt_neox_layers_9_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[147] gpt_neox_layers_9_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[148] gpt_neox_layers_9_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[149] gpt_neox_layers_9_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[150] gpt_neox_layers_9_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[151] gpt_neox_layers_9_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[152] gpt_neox_layers_9_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = 
packed_params[153] gpt_neox_layers_9_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[154] gpt_neox_layers_9_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[155] gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[156] gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[157] gpt_neox_layers_9_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[158] gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[159] gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[160] gpt_neox_layers_9_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[161] gpt_neox_layers_10_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[162] gpt_neox_layers_10_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[163] gpt_neox_layers_10_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[164] gpt_neox_layers_10_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[165] gpt_neox_layers_10_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[166] gpt_neox_layers_10_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[167] gpt_neox_layers_10_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[168] gpt_neox_layers_10_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[169] gpt_neox_layers_10_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[170] gpt_neox_layers_10_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[171] gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[172] gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[173] gpt_neox_layers_10_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[174] gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[175] gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[176] gpt_neox_layers_10_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[177] gpt_neox_layers_11_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[178] gpt_neox_layers_11_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[179] gpt_neox_layers_11_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[180] gpt_neox_layers_11_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[181] gpt_neox_layers_11_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[182] gpt_neox_layers_11_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[183] gpt_neox_layers_11_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[184] gpt_neox_layers_11_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[185] gpt_neox_layers_11_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[186] gpt_neox_layers_11_attention_dense_bias2: R.Tensor((2048,), 
dtype="float16") = packed_params[187] gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[188] gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[189] gpt_neox_layers_11_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[190] gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[191] gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[192] gpt_neox_layers_11_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[193] gpt_neox_layers_12_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[194] gpt_neox_layers_12_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[195] gpt_neox_layers_12_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[196] gpt_neox_layers_12_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[197] gpt_neox_layers_12_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[198] gpt_neox_layers_12_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[199] gpt_neox_layers_12_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[200] gpt_neox_layers_12_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[201] gpt_neox_layers_12_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[202] gpt_neox_layers_12_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[203] gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[204] gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[205] gpt_neox_layers_12_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[206] gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[207] gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[208] gpt_neox_layers_12_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[209] gpt_neox_layers_13_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[210] gpt_neox_layers_13_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[211] gpt_neox_layers_13_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[212] gpt_neox_layers_13_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[213] gpt_neox_layers_13_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[214] gpt_neox_layers_13_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[215] gpt_neox_layers_13_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[216] gpt_neox_layers_13_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[217] gpt_neox_layers_13_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[218] gpt_neox_layers_13_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[219] gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[220] 
gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[221] gpt_neox_layers_13_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[222] gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[223] gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[224] gpt_neox_layers_13_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[225] gpt_neox_layers_14_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[226] gpt_neox_layers_14_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[227] gpt_neox_layers_14_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[228] gpt_neox_layers_14_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[229] gpt_neox_layers_14_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[230] gpt_neox_layers_14_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[231] gpt_neox_layers_14_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[232] gpt_neox_layers_14_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[233] gpt_neox_layers_14_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[234] gpt_neox_layers_14_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[235] gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[236] gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[237] gpt_neox_layers_14_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), dtype="float32") = packed_params[238] gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[239] gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[240] gpt_neox_layers_14_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[241] gpt_neox_layers_15_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[242] gpt_neox_layers_15_input_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[243] gpt_neox_layers_15_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[244] gpt_neox_layers_15_post_attention_layernorm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[245] gpt_neox_layers_15_attention_query_key_value_q_weight2: R.Tensor((6144, 256), dtype="uint32") = packed_params[246] gpt_neox_layers_15_attention_query_key_value_q_scale2: R.Tensor((6144, 64), dtype="float16") = packed_params[247] gpt_neox_layers_15_attention_query_key_value_bias2: R.Tensor((6144,), dtype="float16") = packed_params[248] gpt_neox_layers_15_attention_dense_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[249] gpt_neox_layers_15_attention_dense_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[250] gpt_neox_layers_15_attention_dense_bias2: R.Tensor((2048,), dtype="float16") = packed_params[251] gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight2: R.Tensor((8192, 256), dtype="uint32") = packed_params[252] gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale2: R.Tensor((8192, 64), dtype="float16") = packed_params[253] gpt_neox_layers_15_mlp_dense_h_to_4h_bias2: R.Tensor((8192,), 
dtype="float32") = packed_params[254] gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight2: R.Tensor((2048, 1024), dtype="uint32") = packed_params[255] gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale2: R.Tensor((2048, 256), dtype="float16") = packed_params[256] gpt_neox_layers_15_mlp_dense_4h_to_h_bias2: R.Tensor((2048,), dtype="float32") = packed_params[257] gpt_neox_final_layer_norm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[258] gpt_neox_final_layer_norm_bias2: R.Tensor((2048,), dtype="float16") = packed_params[259] embed_out_q_weight2: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[260] embed_out_q_scale2: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[261] layer_norm33: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(input_embed, gpt_neox_layers_0_input_layernorm_weight2, gpt_neox_layers_0_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv83 = R.call_tir(cls.dequantize1, (gpt_neox_layers_0_attention_query_key_value_q_weight2, gpt_neox_layers_0_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims65: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv83, axes=None) matmul65: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm33, permute_dims65, out_dtype="void") add96: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul65, gpt_neox_layers_0_attention_query_key_value_bias2) reshape64: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add96, R.shape([1, 1, 24, 256])) reshape65: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape64, R.shape([1, 24, 256])) lv84 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape65), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape66: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv84, R.shape([1, 1, 8, 256])) reshape67: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape66, R.shape([1, 1, 2048])) lv85 = R.call_tir(cls.dequantize2, (gpt_neox_layers_0_attention_dense_q_weight2, gpt_neox_layers_0_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims66: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv85, axes=None) matmul66: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape67, permute_dims66, out_dtype="void") add97: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul66, gpt_neox_layers_0_attention_dense_bias2) layer_norm34: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(input_embed, gpt_neox_layers_0_post_attention_layernorm_weight2, gpt_neox_layers_0_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv86 = R.call_tir(cls.dequantize3, (gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims67: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv86, axes=None) matmul67: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm34, permute_dims67, out_dtype="float32") add98: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul67, gpt_neox_layers_0_mlp_dense_h_to_4h_bias2) gelu16: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add98) astype33: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu16, dtype="float16") lv87 = R.call_tir(cls.dequantize4, (gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight2, 
gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims68: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv87, axes=None) matmul68: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype33, permute_dims68, out_dtype="float32") add99: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul68, gpt_neox_layers_0_mlp_dense_4h_to_h_bias2) astype34: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add99, dtype="float16") add100: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype34, add97) add101: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add100, input_embed) layer_norm35: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add101, gpt_neox_layers_1_input_layernorm_weight2, gpt_neox_layers_1_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv88 = R.call_tir(cls.dequantize1, (gpt_neox_layers_1_attention_query_key_value_q_weight2, gpt_neox_layers_1_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims69: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv88, axes=None) matmul69: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm35, permute_dims69, out_dtype="void") add102: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul69, gpt_neox_layers_1_attention_query_key_value_bias2) reshape68: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add102, R.shape([1, 1, 24, 256])) reshape69: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape68, R.shape([1, 24, 256])) lv89 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape69), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape70: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv89, R.shape([1, 1, 8, 256])) reshape71: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape70, R.shape([1, 1, 2048])) lv90 = R.call_tir(cls.dequantize2, (gpt_neox_layers_1_attention_dense_q_weight2, gpt_neox_layers_1_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims70: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv90, axes=None) matmul70: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape71, permute_dims70, out_dtype="void") add103: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul70, gpt_neox_layers_1_attention_dense_bias2) layer_norm36: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add101, gpt_neox_layers_1_post_attention_layernorm_weight2, gpt_neox_layers_1_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv91 = R.call_tir(cls.dequantize3, (gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims71: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv91, axes=None) matmul71: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm36, permute_dims71, out_dtype="float32") add104: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul71, gpt_neox_layers_1_mlp_dense_h_to_4h_bias2) gelu17: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add104) astype35: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu17, dtype="float16") lv92 = R.call_tir(cls.dequantize4, (gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) 
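# NOTE (editorial): each of the 16 layers repeats the pattern above verbatim, changing only
# the parameter names and the layer index passed as R.prim_value(...) to the paged-KV-cache
# attention builtin, so attention reads and writes that layer's cache pages. GPT-NeoX uses the
# parallel-residual form: both layer norms read the same block input, and the attention and
# MLP branches are summed back onto it. Schematically (editorial pseudocode, not module code):
#
#     attn_out = dense(kv_cache_attention(qkv(ln_in(x))))        # float16 throughout
#     mlp_out  = cast_f16(h4_to_h(gelu(h_to_4h(ln_post(x)))))    # matmuls accumulate in f32
#     x_next   = mlp_out + attn_out + x
#
# Each dense step is dequantize -> permute_dims (transpose) -> matmul; the two MLP matmuls set
# out_dtype="float32" and the GELU runs in float32 before R.astype casts back to float16.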
permute_dims72: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv92, axes=None) matmul72: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype35, permute_dims72, out_dtype="float32") add105: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul72, gpt_neox_layers_1_mlp_dense_4h_to_h_bias2) astype36: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add105, dtype="float16") add106: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype36, add103) add107: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add106, add101) layer_norm37: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add107, gpt_neox_layers_2_input_layernorm_weight2, gpt_neox_layers_2_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv93 = R.call_tir(cls.dequantize1, (gpt_neox_layers_2_attention_query_key_value_q_weight2, gpt_neox_layers_2_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims73: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv93, axes=None) matmul73: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm37, permute_dims73, out_dtype="void") add108: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul73, gpt_neox_layers_2_attention_query_key_value_bias2) reshape72: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add108, R.shape([1, 1, 24, 256])) reshape73: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape72, R.shape([1, 24, 256])) lv94 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape73), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape74: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv94, R.shape([1, 1, 8, 256])) reshape75: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape74, R.shape([1, 1, 2048])) lv95 = R.call_tir(cls.dequantize2, (gpt_neox_layers_2_attention_dense_q_weight2, gpt_neox_layers_2_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims74: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv95, axes=None) matmul74: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape75, permute_dims74, out_dtype="void") add109: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul74, gpt_neox_layers_2_attention_dense_bias2) layer_norm38: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add107, gpt_neox_layers_2_post_attention_layernorm_weight2, gpt_neox_layers_2_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv96 = R.call_tir(cls.dequantize3, (gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims75: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv96, axes=None) matmul75: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm38, permute_dims75, out_dtype="float32") add110: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul75, gpt_neox_layers_2_mlp_dense_h_to_4h_bias2) gelu18: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add110) astype37: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu18, dtype="float16") lv97 = R.call_tir(cls.dequantize4, (gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims76: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv97, axes=None) matmul76: R.Tensor((1, 
1, 2048), dtype="float32") = R.matmul(astype37, permute_dims76, out_dtype="float32") add111: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul76, gpt_neox_layers_2_mlp_dense_4h_to_h_bias2) astype38: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add111, dtype="float16") add112: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype38, add109) add113: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add112, add107) layer_norm39: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add113, gpt_neox_layers_3_input_layernorm_weight2, gpt_neox_layers_3_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv98 = R.call_tir(cls.dequantize1, (gpt_neox_layers_3_attention_query_key_value_q_weight2, gpt_neox_layers_3_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims77: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv98, axes=None) matmul77: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm39, permute_dims77, out_dtype="void") add114: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul77, gpt_neox_layers_3_attention_query_key_value_bias2) reshape76: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add114, R.shape([1, 1, 24, 256])) reshape77: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape76, R.shape([1, 24, 256])) lv99 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape77), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape78: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv99, R.shape([1, 1, 8, 256])) reshape79: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape78, R.shape([1, 1, 2048])) lv100 = R.call_tir(cls.dequantize2, (gpt_neox_layers_3_attention_dense_q_weight2, gpt_neox_layers_3_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims78: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv100, axes=None) matmul78: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape79, permute_dims78, out_dtype="void") add115: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul78, gpt_neox_layers_3_attention_dense_bias2) layer_norm40: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add113, gpt_neox_layers_3_post_attention_layernorm_weight2, gpt_neox_layers_3_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv101 = R.call_tir(cls.dequantize3, (gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims79: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv101, axes=None) matmul79: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm40, permute_dims79, out_dtype="float32") add116: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul79, gpt_neox_layers_3_mlp_dense_h_to_4h_bias2) gelu19: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add116) astype39: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu19, dtype="float16") lv102 = R.call_tir(cls.dequantize4, (gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims80: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv102, axes=None) matmul80: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype39, permute_dims80, out_dtype="float32") add117: R.Tensor((1, 1, 
2048), dtype="float32") = R.add(matmul80, gpt_neox_layers_3_mlp_dense_4h_to_h_bias2) astype40: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add117, dtype="float16") add118: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype40, add115) add119: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add118, add113) layer_norm41: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add119, gpt_neox_layers_4_input_layernorm_weight2, gpt_neox_layers_4_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv103 = R.call_tir(cls.dequantize1, (gpt_neox_layers_4_attention_query_key_value_q_weight2, gpt_neox_layers_4_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims81: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv103, axes=None) matmul81: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm41, permute_dims81, out_dtype="void") add120: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul81, gpt_neox_layers_4_attention_query_key_value_bias2) reshape80: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add120, R.shape([1, 1, 24, 256])) reshape81: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape80, R.shape([1, 24, 256])) lv104 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape81), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape82: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv104, R.shape([1, 1, 8, 256])) reshape83: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape82, R.shape([1, 1, 2048])) lv105 = R.call_tir(cls.dequantize2, (gpt_neox_layers_4_attention_dense_q_weight2, gpt_neox_layers_4_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims82: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv105, axes=None) matmul82: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape83, permute_dims82, out_dtype="void") add121: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul82, gpt_neox_layers_4_attention_dense_bias2) layer_norm42: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add119, gpt_neox_layers_4_post_attention_layernorm_weight2, gpt_neox_layers_4_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv106 = R.call_tir(cls.dequantize3, (gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims83: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv106, axes=None) matmul83: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm42, permute_dims83, out_dtype="float32") add122: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul83, gpt_neox_layers_4_mlp_dense_h_to_4h_bias2) gelu20: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add122) astype41: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu20, dtype="float16") lv107 = R.call_tir(cls.dequantize4, (gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims84: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv107, axes=None) matmul84: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype41, permute_dims84, out_dtype="float32") add123: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul84, gpt_neox_layers_4_mlp_dense_4h_to_h_bias2) astype42: 
R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add123, dtype="float16") add124: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype42, add121) add125: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add124, add119) layer_norm43: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add125, gpt_neox_layers_5_input_layernorm_weight2, gpt_neox_layers_5_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv108 = R.call_tir(cls.dequantize1, (gpt_neox_layers_5_attention_query_key_value_q_weight2, gpt_neox_layers_5_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims85: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv108, axes=None) matmul85: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm43, permute_dims85, out_dtype="void") add126: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul85, gpt_neox_layers_5_attention_query_key_value_bias2) reshape84: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add126, R.shape([1, 1, 24, 256])) reshape85: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape84, R.shape([1, 24, 256])) lv109 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape85), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape86: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv109, R.shape([1, 1, 8, 256])) reshape87: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape86, R.shape([1, 1, 2048])) lv110 = R.call_tir(cls.dequantize2, (gpt_neox_layers_5_attention_dense_q_weight2, gpt_neox_layers_5_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims86: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv110, axes=None) matmul86: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape87, permute_dims86, out_dtype="void") add127: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul86, gpt_neox_layers_5_attention_dense_bias2) layer_norm44: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add125, gpt_neox_layers_5_post_attention_layernorm_weight2, gpt_neox_layers_5_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv111 = R.call_tir(cls.dequantize3, (gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims87: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv111, axes=None) matmul87: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm44, permute_dims87, out_dtype="float32") add128: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul87, gpt_neox_layers_5_mlp_dense_h_to_4h_bias2) gelu21: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add128) astype43: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu21, dtype="float16") lv112 = R.call_tir(cls.dequantize4, (gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims88: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv112, axes=None) matmul88: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype43, permute_dims88, out_dtype="float32") add129: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul88, gpt_neox_layers_5_mlp_dense_4h_to_h_bias2) astype44: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add129, dtype="float16") add130: R.Tensor((1, 1, 
2048), dtype="float16") = R.add(astype44, add127) add131: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add130, add125) layer_norm45: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add131, gpt_neox_layers_6_input_layernorm_weight2, gpt_neox_layers_6_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv113 = R.call_tir(cls.dequantize1, (gpt_neox_layers_6_attention_query_key_value_q_weight2, gpt_neox_layers_6_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims89: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv113, axes=None) matmul89: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm45, permute_dims89, out_dtype="void") add132: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul89, gpt_neox_layers_6_attention_query_key_value_bias2) reshape88: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add132, R.shape([1, 1, 24, 256])) reshape89: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape88, R.shape([1, 24, 256])) lv114 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape89), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape90: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv114, R.shape([1, 1, 8, 256])) reshape91: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape90, R.shape([1, 1, 2048])) lv115 = R.call_tir(cls.dequantize2, (gpt_neox_layers_6_attention_dense_q_weight2, gpt_neox_layers_6_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims90: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv115, axes=None) matmul90: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape91, permute_dims90, out_dtype="void") add133: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul90, gpt_neox_layers_6_attention_dense_bias2) layer_norm46: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add131, gpt_neox_layers_6_post_attention_layernorm_weight2, gpt_neox_layers_6_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv116 = R.call_tir(cls.dequantize3, (gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims91: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv116, axes=None) matmul91: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm46, permute_dims91, out_dtype="float32") add134: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul91, gpt_neox_layers_6_mlp_dense_h_to_4h_bias2) gelu22: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add134) astype45: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu22, dtype="float16") lv117 = R.call_tir(cls.dequantize4, (gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims92: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv117, axes=None) matmul92: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype45, permute_dims92, out_dtype="float32") add135: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul92, gpt_neox_layers_6_mlp_dense_4h_to_h_bias2) astype46: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add135, dtype="float16") add136: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype46, add133) add137: R.Tensor((1, 1, 2048), dtype="float16") = 
R.add(add136, add131) layer_norm47: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add137, gpt_neox_layers_7_input_layernorm_weight2, gpt_neox_layers_7_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv118 = R.call_tir(cls.dequantize1, (gpt_neox_layers_7_attention_query_key_value_q_weight2, gpt_neox_layers_7_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims93: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv118, axes=None) matmul93: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm47, permute_dims93, out_dtype="void") add138: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul93, gpt_neox_layers_7_attention_query_key_value_bias2) reshape92: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add138, R.shape([1, 1, 24, 256])) reshape93: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape92, R.shape([1, 24, 256])) lv119 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape93), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape94: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv119, R.shape([1, 1, 8, 256])) reshape95: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape94, R.shape([1, 1, 2048])) lv120 = R.call_tir(cls.dequantize2, (gpt_neox_layers_7_attention_dense_q_weight2, gpt_neox_layers_7_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims94: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv120, axes=None) matmul94: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape95, permute_dims94, out_dtype="void") add139: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul94, gpt_neox_layers_7_attention_dense_bias2) layer_norm48: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add137, gpt_neox_layers_7_post_attention_layernorm_weight2, gpt_neox_layers_7_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv121 = R.call_tir(cls.dequantize3, (gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims95: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv121, axes=None) matmul95: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm48, permute_dims95, out_dtype="float32") add140: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul95, gpt_neox_layers_7_mlp_dense_h_to_4h_bias2) gelu23: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add140) astype47: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu23, dtype="float16") lv122 = R.call_tir(cls.dequantize4, (gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims96: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv122, axes=None) matmul96: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype47, permute_dims96, out_dtype="float32") add141: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul96, gpt_neox_layers_7_mlp_dense_4h_to_h_bias2) astype48: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add141, dtype="float16") add142: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype48, add139) add143: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add142, add137) layer_norm49: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add143, 
gpt_neox_layers_8_input_layernorm_weight2, gpt_neox_layers_8_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv123 = R.call_tir(cls.dequantize1, (gpt_neox_layers_8_attention_query_key_value_q_weight2, gpt_neox_layers_8_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims97: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv123, axes=None) matmul97: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm49, permute_dims97, out_dtype="void") add144: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul97, gpt_neox_layers_8_attention_query_key_value_bias2) reshape96: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add144, R.shape([1, 1, 24, 256])) reshape97: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape96, R.shape([1, 24, 256])) lv124 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape97), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape98: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv124, R.shape([1, 1, 8, 256])) reshape99: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape98, R.shape([1, 1, 2048])) lv125 = R.call_tir(cls.dequantize2, (gpt_neox_layers_8_attention_dense_q_weight2, gpt_neox_layers_8_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims98: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv125, axes=None) matmul98: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape99, permute_dims98, out_dtype="void") add145: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul98, gpt_neox_layers_8_attention_dense_bias2) layer_norm50: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add143, gpt_neox_layers_8_post_attention_layernorm_weight2, gpt_neox_layers_8_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv126 = R.call_tir(cls.dequantize3, (gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims99: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv126, axes=None) matmul99: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm50, permute_dims99, out_dtype="float32") add146: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul99, gpt_neox_layers_8_mlp_dense_h_to_4h_bias2) gelu24: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add146) astype49: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu24, dtype="float16") lv127 = R.call_tir(cls.dequantize4, (gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims100: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv127, axes=None) matmul100: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype49, permute_dims100, out_dtype="float32") add147: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul100, gpt_neox_layers_8_mlp_dense_4h_to_h_bias2) astype50: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add147, dtype="float16") add148: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype50, add145) add149: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add148, add143) layer_norm51: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add149, gpt_neox_layers_9_input_layernorm_weight2, gpt_neox_layers_9_input_layernorm_bias2, axes=[-1], 
epsilon=1.0000000000000001e-05, center=True, scale=True) lv128 = R.call_tir(cls.dequantize1, (gpt_neox_layers_9_attention_query_key_value_q_weight2, gpt_neox_layers_9_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims101: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv128, axes=None) matmul101: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm51, permute_dims101, out_dtype="void") add150: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul101, gpt_neox_layers_9_attention_query_key_value_bias2) reshape100: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add150, R.shape([1, 1, 24, 256])) reshape101: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape100, R.shape([1, 24, 256])) lv129 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape101), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape102: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv129, R.shape([1, 1, 8, 256])) reshape103: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape102, R.shape([1, 1, 2048])) lv130 = R.call_tir(cls.dequantize2, (gpt_neox_layers_9_attention_dense_q_weight2, gpt_neox_layers_9_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims102: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv130, axes=None) matmul102: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape103, permute_dims102, out_dtype="void") add151: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul102, gpt_neox_layers_9_attention_dense_bias2) layer_norm52: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add149, gpt_neox_layers_9_post_attention_layernorm_weight2, gpt_neox_layers_9_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv131 = R.call_tir(cls.dequantize3, (gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims103: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv131, axes=None) matmul103: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm52, permute_dims103, out_dtype="float32") add152: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul103, gpt_neox_layers_9_mlp_dense_h_to_4h_bias2) gelu25: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add152) astype51: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu25, dtype="float16") lv132 = R.call_tir(cls.dequantize4, (gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims104: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv132, axes=None) matmul104: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype51, permute_dims104, out_dtype="float32") add153: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul104, gpt_neox_layers_9_mlp_dense_4h_to_h_bias2) astype52: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add153, dtype="float16") add154: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype52, add151) add155: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add154, add149) layer_norm53: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add155, gpt_neox_layers_10_input_layernorm_weight2, gpt_neox_layers_10_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv133 = 
R.call_tir(cls.dequantize1, (gpt_neox_layers_10_attention_query_key_value_q_weight2, gpt_neox_layers_10_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims105: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv133, axes=None) matmul105: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm53, permute_dims105, out_dtype="void") add156: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul105, gpt_neox_layers_10_attention_query_key_value_bias2) reshape104: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add156, R.shape([1, 1, 24, 256])) reshape105: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape104, R.shape([1, 24, 256])) lv134 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape105), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape106: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv134, R.shape([1, 1, 8, 256])) reshape107: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape106, R.shape([1, 1, 2048])) lv135 = R.call_tir(cls.dequantize2, (gpt_neox_layers_10_attention_dense_q_weight2, gpt_neox_layers_10_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims106: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv135, axes=None) matmul106: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape107, permute_dims106, out_dtype="void") add157: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul106, gpt_neox_layers_10_attention_dense_bias2) layer_norm54: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add155, gpt_neox_layers_10_post_attention_layernorm_weight2, gpt_neox_layers_10_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv136 = R.call_tir(cls.dequantize3, (gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims107: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv136, axes=None) matmul107: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm54, permute_dims107, out_dtype="float32") add158: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul107, gpt_neox_layers_10_mlp_dense_h_to_4h_bias2) gelu26: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add158) astype53: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu26, dtype="float16") lv137 = R.call_tir(cls.dequantize4, (gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims108: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv137, axes=None) matmul108: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype53, permute_dims108, out_dtype="float32") add159: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul108, gpt_neox_layers_10_mlp_dense_4h_to_h_bias2) astype54: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add159, dtype="float16") add160: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype54, add157) add161: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add160, add155) layer_norm55: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add161, gpt_neox_layers_11_input_layernorm_weight2, gpt_neox_layers_11_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv138 = R.call_tir(cls.dequantize1, 
(gpt_neox_layers_11_attention_query_key_value_q_weight2, gpt_neox_layers_11_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims109: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv138, axes=None) matmul109: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm55, permute_dims109, out_dtype="void") add162: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul109, gpt_neox_layers_11_attention_query_key_value_bias2) reshape108: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add162, R.shape([1, 1, 24, 256])) reshape109: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape108, R.shape([1, 24, 256])) lv139 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape109), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape110: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv139, R.shape([1, 1, 8, 256])) reshape111: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape110, R.shape([1, 1, 2048])) lv140 = R.call_tir(cls.dequantize2, (gpt_neox_layers_11_attention_dense_q_weight2, gpt_neox_layers_11_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims110: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv140, axes=None) matmul110: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape111, permute_dims110, out_dtype="void") add163: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul110, gpt_neox_layers_11_attention_dense_bias2) layer_norm56: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add161, gpt_neox_layers_11_post_attention_layernorm_weight2, gpt_neox_layers_11_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv141 = R.call_tir(cls.dequantize3, (gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims111: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv141, axes=None) matmul111: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm56, permute_dims111, out_dtype="float32") add164: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul111, gpt_neox_layers_11_mlp_dense_h_to_4h_bias2) gelu27: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add164) astype55: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu27, dtype="float16") lv142 = R.call_tir(cls.dequantize4, (gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims112: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv142, axes=None) matmul112: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype55, permute_dims112, out_dtype="float32") add165: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul112, gpt_neox_layers_11_mlp_dense_4h_to_h_bias2) astype56: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add165, dtype="float16") add166: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype56, add163) add167: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add166, add161) layer_norm57: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add167, gpt_neox_layers_12_input_layernorm_weight2, gpt_neox_layers_12_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv143 = R.call_tir(cls.dequantize1, (gpt_neox_layers_12_attention_query_key_value_q_weight2, 
gpt_neox_layers_12_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims113: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv143, axes=None) matmul113: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm57, permute_dims113, out_dtype="void") add168: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul113, gpt_neox_layers_12_attention_query_key_value_bias2) reshape112: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add168, R.shape([1, 1, 24, 256])) reshape113: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape112, R.shape([1, 24, 256])) lv144 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape113), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape114: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv144, R.shape([1, 1, 8, 256])) reshape115: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape114, R.shape([1, 1, 2048])) lv145 = R.call_tir(cls.dequantize2, (gpt_neox_layers_12_attention_dense_q_weight2, gpt_neox_layers_12_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims114: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv145, axes=None) matmul114: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape115, permute_dims114, out_dtype="void") add169: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul114, gpt_neox_layers_12_attention_dense_bias2) layer_norm58: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add167, gpt_neox_layers_12_post_attention_layernorm_weight2, gpt_neox_layers_12_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv146 = R.call_tir(cls.dequantize3, (gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims115: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv146, axes=None) matmul115: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm58, permute_dims115, out_dtype="float32") add170: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul115, gpt_neox_layers_12_mlp_dense_h_to_4h_bias2) gelu28: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add170) astype57: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu28, dtype="float16") lv147 = R.call_tir(cls.dequantize4, (gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims116: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv147, axes=None) matmul116: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype57, permute_dims116, out_dtype="float32") add171: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul116, gpt_neox_layers_12_mlp_dense_4h_to_h_bias2) astype58: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add171, dtype="float16") add172: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype58, add169) add173: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add172, add167) layer_norm59: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add173, gpt_neox_layers_13_input_layernorm_weight2, gpt_neox_layers_13_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv148 = R.call_tir(cls.dequantize1, (gpt_neox_layers_13_attention_query_key_value_q_weight2, gpt_neox_layers_13_attention_query_key_value_q_scale2), 
out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims117: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv148, axes=None) matmul117: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm59, permute_dims117, out_dtype="void") add174: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul117, gpt_neox_layers_13_attention_query_key_value_bias2) reshape116: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add174, R.shape([1, 1, 24, 256])) reshape117: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape116, R.shape([1, 24, 256])) lv149 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape117), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape118: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv149, R.shape([1, 1, 8, 256])) reshape119: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape118, R.shape([1, 1, 2048])) lv150 = R.call_tir(cls.dequantize2, (gpt_neox_layers_13_attention_dense_q_weight2, gpt_neox_layers_13_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims118: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv150, axes=None) matmul118: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape119, permute_dims118, out_dtype="void") add175: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul118, gpt_neox_layers_13_attention_dense_bias2) layer_norm60: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add173, gpt_neox_layers_13_post_attention_layernorm_weight2, gpt_neox_layers_13_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv151 = R.call_tir(cls.dequantize3, (gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims119: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv151, axes=None) matmul119: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm60, permute_dims119, out_dtype="float32") add176: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul119, gpt_neox_layers_13_mlp_dense_h_to_4h_bias2) gelu29: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add176) astype59: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu29, dtype="float16") lv152 = R.call_tir(cls.dequantize4, (gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims120: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv152, axes=None) matmul120: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype59, permute_dims120, out_dtype="float32") add177: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul120, gpt_neox_layers_13_mlp_dense_4h_to_h_bias2) astype60: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add177, dtype="float16") add178: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype60, add175) add179: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add178, add173) layer_norm61: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add179, gpt_neox_layers_14_input_layernorm_weight2, gpt_neox_layers_14_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv153 = R.call_tir(cls.dequantize1, (gpt_neox_layers_14_attention_query_key_value_q_weight2, gpt_neox_layers_14_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) 
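# lv153 above is the layer-14 fused-QKV dequantize; every decoder layer in this
# function repeats the same pattern. Per layer: input layernorm -> fused QKV
# projection (fp16 matmul; out_dtype="void" keeps the input dtype) -> reshape to
# (1, 24, 256), i.e. 8 query + 8 key + 8 value heads of width 256 -> the
# "vm.builtin.attention_kv_cache_attention_with_fused_qkv" packed call (the layer
# index travels as R.prim_value), which returns the (1, 8, 256) output of the 8
# query heads -> attention dense projection. In parallel, the post-attention
# layernorm of the *same* residual input feeds h_to_4h -> GELU -> 4h_to_h, with
# both MLP matmuls accumulating in fp32 before casting back to fp16. The two
# branches combine with the GPT-NeoX parallel residual:
# out = mlp_out + attn_out + layer_input.
#
# A rough NumPy sketch of what the dequantize kernels compute, inferred from the
# buffer shapes (eight 4-bit nibbles per uint32 word, one fp16 scale per group of
# 32 weights). The centering constant 7 follows MLC's usual q4f16_1 convention
# and is an assumption here; the authoritative definitions are the dequantize*
# TIR kernels in this module:
#
#   import numpy as np
#
#   def dequantize_q4f16_1(q_weight, q_scale):
#       # q_weight: (n, k // 8) uint32, q_scale: (n, k // 32) float16
#       shifts = 4 * np.arange(8, dtype=np.uint32)
#       nibbles = (q_weight[:, :, None] >> shifts) & np.uint32(0xF)   # (n, k//8, 8)
#       w = nibbles.reshape(q_weight.shape[0], -1).astype(np.float16)
#       w = w - np.float16(7)                                         # assumed zero point
#       return w * np.repeat(q_scale, 32, axis=1)                     # (n, k) float16
#
# For the QKV weight above this yields the (6144, 2048) fp16 matrix that
# permute_dims121 transposes, so matmul121 computes
# (1, 1, 2048) @ (2048, 6144) -> (1, 1, 6144).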
permute_dims121: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv153, axes=None) matmul121: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm61, permute_dims121, out_dtype="void") add180: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul121, gpt_neox_layers_14_attention_query_key_value_bias2) reshape120: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add180, R.shape([1, 1, 24, 256])) reshape121: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape120, R.shape([1, 24, 256])) lv154 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape121), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape122: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv154, R.shape([1, 1, 8, 256])) reshape123: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape122, R.shape([1, 1, 2048])) lv155 = R.call_tir(cls.dequantize2, (gpt_neox_layers_14_attention_dense_q_weight2, gpt_neox_layers_14_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims122: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv155, axes=None) matmul122: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape123, permute_dims122, out_dtype="void") add181: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul122, gpt_neox_layers_14_attention_dense_bias2) layer_norm62: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add179, gpt_neox_layers_14_post_attention_layernorm_weight2, gpt_neox_layers_14_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv156 = R.call_tir(cls.dequantize3, (gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims123: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv156, axes=None) matmul123: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm62, permute_dims123, out_dtype="float32") add182: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul123, gpt_neox_layers_14_mlp_dense_h_to_4h_bias2) gelu30: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add182) astype61: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu30, dtype="float16") lv157 = R.call_tir(cls.dequantize4, (gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims124: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv157, axes=None) matmul124: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype61, permute_dims124, out_dtype="float32") add183: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul124, gpt_neox_layers_14_mlp_dense_4h_to_h_bias2) astype62: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add183, dtype="float16") add184: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype62, add181) add185: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add184, add179) layer_norm63: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add185, gpt_neox_layers_15_input_layernorm_weight2, gpt_neox_layers_15_input_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv158 = R.call_tir(cls.dequantize1, (gpt_neox_layers_15_attention_query_key_value_q_weight2, gpt_neox_layers_15_attention_query_key_value_q_scale2), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims125: R.Tensor((2048, 6144), dtype="float16") = 
R.permute_dims(lv158, axes=None) matmul125: R.Tensor((1, 1, 6144), dtype="float16") = R.matmul(layer_norm63, permute_dims125, out_dtype="void") add186: R.Tensor((1, 1, 6144), dtype="float16") = R.add(matmul125, gpt_neox_layers_15_attention_query_key_value_bias2) reshape124: R.Tensor((1, 1, 24, 256), dtype="float16") = R.reshape(add186, R.shape([1, 1, 24, 256])) reshape125: R.Tensor((1, 24, 256), dtype="float16") = R.reshape(reshape124, R.shape([1, 24, 256])) lv159 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape125), out_sinfo=R.Tensor((1, 8, 256), dtype="float16")) reshape126: R.Tensor((1, 1, 8, 256), dtype="float16") = R.reshape(lv159, R.shape([1, 1, 8, 256])) reshape127: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape126, R.shape([1, 1, 2048])) lv160 = R.call_tir(cls.dequantize2, (gpt_neox_layers_15_attention_dense_q_weight2, gpt_neox_layers_15_attention_dense_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims126: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv160, axes=None) matmul126: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape127, permute_dims126, out_dtype="void") add187: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul126, gpt_neox_layers_15_attention_dense_bias2) layer_norm64: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add185, gpt_neox_layers_15_post_attention_layernorm_weight2, gpt_neox_layers_15_post_attention_layernorm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv161 = R.call_tir(cls.dequantize3, (gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight2, gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale2), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims127: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv161, axes=None) matmul127: R.Tensor((1, 1, 8192), dtype="float32") = R.matmul(layer_norm64, permute_dims127, out_dtype="float32") add188: R.Tensor((1, 1, 8192), dtype="float32") = R.add(matmul127, gpt_neox_layers_15_mlp_dense_h_to_4h_bias2) gelu31: R.Tensor((1, 1, 8192), dtype="float32") = R.nn.gelu(add188) astype63: R.Tensor((1, 1, 8192), dtype="float16") = R.astype(gelu31, dtype="float16") lv162 = R.call_tir(cls.dequantize4, (gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight2, gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale2), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims128: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv162, axes=None) matmul128: R.Tensor((1, 1, 2048), dtype="float32") = R.matmul(astype63, permute_dims128, out_dtype="float32") add189: R.Tensor((1, 1, 2048), dtype="float32") = R.add(matmul128, gpt_neox_layers_15_mlp_dense_4h_to_h_bias2) astype64: R.Tensor((1, 1, 2048), dtype="float16") = R.astype(add189, dtype="float16") add190: R.Tensor((1, 1, 2048), dtype="float16") = R.add(astype64, add187) add191: R.Tensor((1, 1, 2048), dtype="float16") = R.add(add190, add185) layer_norm65: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.layer_norm(add191, gpt_neox_final_layer_norm_weight2, gpt_neox_final_layer_norm_bias2, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv163 = R.call_tir(cls.dequantize, (embed_out_q_weight2, embed_out_q_scale2), out_sinfo=R.Tensor((vocab_size, 2048), dtype="float16")) permute_dims129: R.Tensor((2048, vocab_size), dtype="float16") = R.permute_dims(lv163, axes=None) matmul129: R.Tensor((1, 1, vocab_size), dtype="float16") = R.matmul(layer_norm65, permute_dims129, 
out_dtype="void") astype65: R.Tensor((1, 1, vocab_size), dtype="float32") = R.astype(matmul129, dtype="float32") gv2: R.Tuple(R.Tensor((1, 1, vocab_size), dtype="float32"), R.Object) = astype65, paged_kv_cache R.output(gv2) return gv2 @R.function def embed(input_ids: R.Tensor(("seq_len",), dtype="int32"), packed_params: R.Tuple(R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"))) -> R.Tensor(("seq_len", 2048), dtype="float16"): seq_len = T.int64() vocab_size = T.int64() R.func_attr({"num_input": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module with R.dataflow(): gpt_neox_embed_in_q_weight: R.Tensor((vocab_size, 256), dtype="uint32") = 
packed_params[0] gpt_neox_embed_in_q_scale: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[1] gpt_neox_layers_0_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[2] gpt_neox_layers_0_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[3] gpt_neox_layers_0_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[4] gpt_neox_layers_0_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[5] gpt_neox_layers_0_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[6] gpt_neox_layers_0_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[7] gpt_neox_layers_0_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[8] gpt_neox_layers_0_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[9] gpt_neox_layers_0_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[10] gpt_neox_layers_0_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[11] gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[12] gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[13] gpt_neox_layers_0_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[14] gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[15] gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[16] gpt_neox_layers_0_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[17] gpt_neox_layers_1_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[18] gpt_neox_layers_1_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[19] gpt_neox_layers_1_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[20] gpt_neox_layers_1_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[21] gpt_neox_layers_1_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[22] gpt_neox_layers_1_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[23] gpt_neox_layers_1_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[24] gpt_neox_layers_1_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[25] gpt_neox_layers_1_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[26] gpt_neox_layers_1_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[27] gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[28] gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[29] gpt_neox_layers_1_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[30] gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[31] gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[32] gpt_neox_layers_1_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[33] gpt_neox_layers_2_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[34] gpt_neox_layers_2_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = 
packed_params[35] gpt_neox_layers_2_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[36] gpt_neox_layers_2_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[37] gpt_neox_layers_2_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[38] gpt_neox_layers_2_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[39] gpt_neox_layers_2_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[40] gpt_neox_layers_2_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[41] gpt_neox_layers_2_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[42] gpt_neox_layers_2_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[43] gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[44] gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[45] gpt_neox_layers_2_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[46] gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[47] gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[48] gpt_neox_layers_2_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[49] gpt_neox_layers_3_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[50] gpt_neox_layers_3_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[51] gpt_neox_layers_3_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[52] gpt_neox_layers_3_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[53] gpt_neox_layers_3_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[54] gpt_neox_layers_3_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[55] gpt_neox_layers_3_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[56] gpt_neox_layers_3_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[57] gpt_neox_layers_3_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[58] gpt_neox_layers_3_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[59] gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[60] gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[61] gpt_neox_layers_3_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[62] gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[63] gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[64] gpt_neox_layers_3_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[65] gpt_neox_layers_4_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[66] gpt_neox_layers_4_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[67] gpt_neox_layers_4_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[68] gpt_neox_layers_4_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[69] 
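# The unpacking above walks the flat packed_params tuple in a fixed order:
# entries 0-1 hold the quantized input embedding (q_weight, q_scale), and each
# decoder layer i owns the 16 consecutive entries starting at 2 + 16 * i:
# input LN weight/bias, post-attention LN weight/bias, fused QKV
# q_weight/q_scale/bias, attention dense q_weight/q_scale/bias, h_to_4h
# q_weight/q_scale/bias, and 4h_to_h q_weight/q_scale/bias. A hypothetical
# helper for locating one layer's slice (illustrative only, not part of this
# module):
#
#   def layer_param_slice(layer_idx: int) -> slice:
#       base = 2 + 16 * layer_idx   # layer 4 -> 66, matching packed_params[66] above
#       return slice(base, base + 16)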
gpt_neox_layers_4_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[70] gpt_neox_layers_4_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[71] gpt_neox_layers_4_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[72] gpt_neox_layers_4_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[73] gpt_neox_layers_4_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[74] gpt_neox_layers_4_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[75] gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[76] gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[77] gpt_neox_layers_4_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[78] gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[79] gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[80] gpt_neox_layers_4_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[81] gpt_neox_layers_5_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[82] gpt_neox_layers_5_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[83] gpt_neox_layers_5_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[84] gpt_neox_layers_5_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[85] gpt_neox_layers_5_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[86] gpt_neox_layers_5_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[87] gpt_neox_layers_5_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[88] gpt_neox_layers_5_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[89] gpt_neox_layers_5_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[90] gpt_neox_layers_5_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[91] gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[92] gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[93] gpt_neox_layers_5_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[94] gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[95] gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[96] gpt_neox_layers_5_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[97] gpt_neox_layers_6_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[98] gpt_neox_layers_6_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[99] gpt_neox_layers_6_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[100] gpt_neox_layers_6_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[101] gpt_neox_layers_6_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[102] gpt_neox_layers_6_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[103] gpt_neox_layers_6_attention_query_key_value_bias: 
R.Tensor((6144,), dtype="float16") = packed_params[104] gpt_neox_layers_6_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[105] gpt_neox_layers_6_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[106] gpt_neox_layers_6_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[107] gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[108] gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[109] gpt_neox_layers_6_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[110] gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[111] gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[112] gpt_neox_layers_6_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[113] gpt_neox_layers_7_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[114] gpt_neox_layers_7_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[115] gpt_neox_layers_7_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[116] gpt_neox_layers_7_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[117] gpt_neox_layers_7_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[118] gpt_neox_layers_7_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[119] gpt_neox_layers_7_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[120] gpt_neox_layers_7_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[121] gpt_neox_layers_7_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[122] gpt_neox_layers_7_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[123] gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[124] gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[125] gpt_neox_layers_7_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[126] gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[127] gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[128] gpt_neox_layers_7_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[129] gpt_neox_layers_8_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[130] gpt_neox_layers_8_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[131] gpt_neox_layers_8_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[132] gpt_neox_layers_8_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[133] gpt_neox_layers_8_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[134] gpt_neox_layers_8_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[135] gpt_neox_layers_8_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[136] gpt_neox_layers_8_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[137] gpt_neox_layers_8_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = 
packed_params[138] gpt_neox_layers_8_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[139] gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[140] gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[141] gpt_neox_layers_8_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[142] gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[143] gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[144] gpt_neox_layers_8_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[145] gpt_neox_layers_9_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[146] gpt_neox_layers_9_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[147] gpt_neox_layers_9_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[148] gpt_neox_layers_9_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[149] gpt_neox_layers_9_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[150] gpt_neox_layers_9_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[151] gpt_neox_layers_9_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[152] gpt_neox_layers_9_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[153] gpt_neox_layers_9_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[154] gpt_neox_layers_9_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[155] gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[156] gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[157] gpt_neox_layers_9_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[158] gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[159] gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[160] gpt_neox_layers_9_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[161] gpt_neox_layers_10_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[162] gpt_neox_layers_10_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[163] gpt_neox_layers_10_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[164] gpt_neox_layers_10_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[165] gpt_neox_layers_10_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[166] gpt_neox_layers_10_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[167] gpt_neox_layers_10_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[168] gpt_neox_layers_10_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[169] gpt_neox_layers_10_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[170] gpt_neox_layers_10_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[171] gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[172] 
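# Shape bookkeeping for the 4-bit layout: an (8192, 2048) fp16 h_to_4h weight
# would take 8192 * 2048 * 2 B = 32 MiB, while its quantized form is
# (8192, 256) uint32 (8 MiB) plus (8192, 64) fp16 scales (1 MiB), i.e. roughly
# 28% of the fp16 footprint. The same ~3.6x compression holds for the QKV,
# dense, and 4h_to_h matrices: each q_weight is width/8 uint32 columns and each
# q_scale width/32 fp16 columns of the dequantized (rows, width) matrix.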
gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[173] gpt_neox_layers_10_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[174] gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[175] gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[176] gpt_neox_layers_10_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[177] gpt_neox_layers_11_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[178] gpt_neox_layers_11_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[179] gpt_neox_layers_11_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[180] gpt_neox_layers_11_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[181] gpt_neox_layers_11_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[182] gpt_neox_layers_11_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[183] gpt_neox_layers_11_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[184] gpt_neox_layers_11_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[185] gpt_neox_layers_11_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[186] gpt_neox_layers_11_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[187] gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[188] gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[189] gpt_neox_layers_11_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[190] gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[191] gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[192] gpt_neox_layers_11_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[193] gpt_neox_layers_12_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[194] gpt_neox_layers_12_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[195] gpt_neox_layers_12_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[196] gpt_neox_layers_12_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[197] gpt_neox_layers_12_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[198] gpt_neox_layers_12_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[199] gpt_neox_layers_12_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[200] gpt_neox_layers_12_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[201] gpt_neox_layers_12_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[202] gpt_neox_layers_12_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[203] gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[204] gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[205] gpt_neox_layers_12_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[206] 
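# Note the mixed bias dtypes: both MLP biases are fp32 (offsets 12 and 15
# within a layer's 16-entry block, e.g. packed_params[206] above and [209]
# below), matching the fp32-accumulating h_to_4h / 4h_to_h matmuls in the
# decode and prefill bodies, whereas the layernorm, QKV, and dense parameters
# are fp16 (or packed uint32 for the 4-bit weights).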
gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[207] gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[208] gpt_neox_layers_12_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[209] gpt_neox_layers_13_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[210] gpt_neox_layers_13_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[211] gpt_neox_layers_13_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[212] gpt_neox_layers_13_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[213] gpt_neox_layers_13_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[214] gpt_neox_layers_13_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[215] gpt_neox_layers_13_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[216] gpt_neox_layers_13_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[217] gpt_neox_layers_13_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[218] gpt_neox_layers_13_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[219] gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[220] gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[221] gpt_neox_layers_13_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[222] gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[223] gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[224] gpt_neox_layers_13_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[225] gpt_neox_layers_14_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[226] gpt_neox_layers_14_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[227] gpt_neox_layers_14_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[228] gpt_neox_layers_14_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[229] gpt_neox_layers_14_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[230] gpt_neox_layers_14_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[231] gpt_neox_layers_14_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[232] gpt_neox_layers_14_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[233] gpt_neox_layers_14_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[234] gpt_neox_layers_14_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[235] gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[236] gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[237] gpt_neox_layers_14_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[238] gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[239] gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[240] 
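# The remainder of the tuple, unpacked below, follows the same pattern: layer 15
# occupies entries 242-257, the final layer norm weight/bias sit at 258-259, and
# the separate (untied) embed_out lm-head q_weight/q_scale at 260-261, for
# 2 + 16 * 16 + 4 = 262 entries in total. Once everything is bound, the embed
# function body reduces to dequantizing the (vocab_size, 2048) embedding table
# and gathering rows with R.take; in plain terms (illustrative only):
#
#   table = dequantize(embed_in_q_weight, embed_in_q_scale)   # (vocab_size, 2048) fp16
#   hidden = table[input_ids]                                 # (seq_len, 2048) fp16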
gpt_neox_layers_14_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[241] gpt_neox_layers_15_input_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[242] gpt_neox_layers_15_input_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[243] gpt_neox_layers_15_post_attention_layernorm_weight: R.Tensor((2048,), dtype="float16") = packed_params[244] gpt_neox_layers_15_post_attention_layernorm_bias: R.Tensor((2048,), dtype="float16") = packed_params[245] gpt_neox_layers_15_attention_query_key_value_q_weight: R.Tensor((6144, 256), dtype="uint32") = packed_params[246] gpt_neox_layers_15_attention_query_key_value_q_scale: R.Tensor((6144, 64), dtype="float16") = packed_params[247] gpt_neox_layers_15_attention_query_key_value_bias: R.Tensor((6144,), dtype="float16") = packed_params[248] gpt_neox_layers_15_attention_dense_q_weight: R.Tensor((2048, 256), dtype="uint32") = packed_params[249] gpt_neox_layers_15_attention_dense_q_scale: R.Tensor((2048, 64), dtype="float16") = packed_params[250] gpt_neox_layers_15_attention_dense_bias: R.Tensor((2048,), dtype="float16") = packed_params[251] gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight: R.Tensor((8192, 256), dtype="uint32") = packed_params[252] gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale: R.Tensor((8192, 64), dtype="float16") = packed_params[253] gpt_neox_layers_15_mlp_dense_h_to_4h_bias: R.Tensor((8192,), dtype="float32") = packed_params[254] gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight: R.Tensor((2048, 1024), dtype="uint32") = packed_params[255] gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale: R.Tensor((2048, 256), dtype="float16") = packed_params[256] gpt_neox_layers_15_mlp_dense_4h_to_h_bias: R.Tensor((2048,), dtype="float32") = packed_params[257] gpt_neox_final_layer_norm_weight: R.Tensor((2048,), dtype="float16") = packed_params[258] gpt_neox_final_layer_norm_bias: R.Tensor((2048,), dtype="float16") = packed_params[259] embed_out_q_weight: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[260] embed_out_q_scale: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[261] lv = R.call_tir(cls.dequantize, (gpt_neox_embed_in_q_weight, gpt_neox_embed_in_q_scale), out_sinfo=R.Tensor((vocab_size, 2048), dtype="float16")) take: R.Tensor((seq_len, 2048), dtype="float16") = R.take(lv, input_ids, axis=0) gv: R.Tensor((seq_len, 2048), dtype="float16") = take R.output(gv) return gv @R.function def prefill(input_embed: R.Tensor((1, "seq_len", 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), 
R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), 
dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), 
R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((8192, 256), dtype="uint32"), R.Tensor((8192, 64), dtype="float16"), R.Tensor((8192,), dtype="float32"), R.Tensor((2048, 1024), dtype="uint32"), R.Tensor((2048, 256), dtype="float16"), R.Tensor((2048,), dtype="float32"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor(("vocab_size", 256), dtype="uint32"), R.Tensor(("vocab_size", 64), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, "vocab_size"), dtype="float32"), R.Object): vocab_size = T.int64() seq_len = T.int64() R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}}) cls = Module with R.dataflow(): gpt_neox_embed_in_q_weight1: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[0] gpt_neox_embed_in_q_scale1: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[1] gpt_neox_layers_0_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[2] gpt_neox_layers_0_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[3] gpt_neox_layers_0_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[4] gpt_neox_layers_0_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[5] gpt_neox_layers_0_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[6] gpt_neox_layers_0_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[7] gpt_neox_layers_0_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[8] gpt_neox_layers_0_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[9] gpt_neox_layers_0_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[10] gpt_neox_layers_0_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[11] gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[12] gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), 
dtype="float16") = packed_params[13] gpt_neox_layers_0_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[14] gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[15] gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[16] gpt_neox_layers_0_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[17] gpt_neox_layers_1_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[18] gpt_neox_layers_1_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[19] gpt_neox_layers_1_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[20] gpt_neox_layers_1_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[21] gpt_neox_layers_1_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[22] gpt_neox_layers_1_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[23] gpt_neox_layers_1_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[24] gpt_neox_layers_1_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[25] gpt_neox_layers_1_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[26] gpt_neox_layers_1_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[27] gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[28] gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[29] gpt_neox_layers_1_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[30] gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[31] gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[32] gpt_neox_layers_1_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[33] gpt_neox_layers_2_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[34] gpt_neox_layers_2_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[35] gpt_neox_layers_2_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[36] gpt_neox_layers_2_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[37] gpt_neox_layers_2_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[38] gpt_neox_layers_2_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[39] gpt_neox_layers_2_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[40] gpt_neox_layers_2_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[41] gpt_neox_layers_2_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[42] gpt_neox_layers_2_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[43] gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[44] gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[45] gpt_neox_layers_2_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[46] gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[47] 
gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[48] gpt_neox_layers_2_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[49] gpt_neox_layers_3_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[50] gpt_neox_layers_3_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[51] gpt_neox_layers_3_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[52] gpt_neox_layers_3_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[53] gpt_neox_layers_3_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[54] gpt_neox_layers_3_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[55] gpt_neox_layers_3_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[56] gpt_neox_layers_3_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[57] gpt_neox_layers_3_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[58] gpt_neox_layers_3_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[59] gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[60] gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[61] gpt_neox_layers_3_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[62] gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[63] gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[64] gpt_neox_layers_3_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[65] gpt_neox_layers_4_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[66] gpt_neox_layers_4_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[67] gpt_neox_layers_4_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[68] gpt_neox_layers_4_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[69] gpt_neox_layers_4_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[70] gpt_neox_layers_4_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[71] gpt_neox_layers_4_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[72] gpt_neox_layers_4_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[73] gpt_neox_layers_4_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[74] gpt_neox_layers_4_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[75] gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[76] gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[77] gpt_neox_layers_4_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[78] gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[79] gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[80] gpt_neox_layers_4_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[81] gpt_neox_layers_5_input_layernorm_weight1: 
R.Tensor((2048,), dtype="float16") = packed_params[82] gpt_neox_layers_5_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[83] gpt_neox_layers_5_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[84] gpt_neox_layers_5_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[85] gpt_neox_layers_5_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[86] gpt_neox_layers_5_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[87] gpt_neox_layers_5_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[88] gpt_neox_layers_5_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[89] gpt_neox_layers_5_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[90] gpt_neox_layers_5_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[91] gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[92] gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[93] gpt_neox_layers_5_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[94] gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[95] gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[96] gpt_neox_layers_5_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[97] gpt_neox_layers_6_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[98] gpt_neox_layers_6_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[99] gpt_neox_layers_6_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[100] gpt_neox_layers_6_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[101] gpt_neox_layers_6_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[102] gpt_neox_layers_6_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[103] gpt_neox_layers_6_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[104] gpt_neox_layers_6_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[105] gpt_neox_layers_6_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[106] gpt_neox_layers_6_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[107] gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[108] gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[109] gpt_neox_layers_6_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[110] gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[111] gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[112] gpt_neox_layers_6_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[113] gpt_neox_layers_7_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[114] gpt_neox_layers_7_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[115] gpt_neox_layers_7_post_attention_layernorm_weight1: R.Tensor((2048,), 
dtype="float16") = packed_params[116] gpt_neox_layers_7_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[117] gpt_neox_layers_7_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[118] gpt_neox_layers_7_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[119] gpt_neox_layers_7_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[120] gpt_neox_layers_7_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[121] gpt_neox_layers_7_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[122] gpt_neox_layers_7_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[123] gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[124] gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[125] gpt_neox_layers_7_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[126] gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[127] gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[128] gpt_neox_layers_7_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[129] gpt_neox_layers_8_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[130] gpt_neox_layers_8_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[131] gpt_neox_layers_8_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[132] gpt_neox_layers_8_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[133] gpt_neox_layers_8_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[134] gpt_neox_layers_8_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[135] gpt_neox_layers_8_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[136] gpt_neox_layers_8_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[137] gpt_neox_layers_8_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[138] gpt_neox_layers_8_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[139] gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[140] gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[141] gpt_neox_layers_8_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[142] gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[143] gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[144] gpt_neox_layers_8_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[145] gpt_neox_layers_9_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[146] gpt_neox_layers_9_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[147] gpt_neox_layers_9_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[148] gpt_neox_layers_9_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[149] gpt_neox_layers_9_attention_query_key_value_q_weight1: R.Tensor((6144, 
256), dtype="uint32") = packed_params[150] gpt_neox_layers_9_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[151] gpt_neox_layers_9_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[152] gpt_neox_layers_9_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[153] gpt_neox_layers_9_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[154] gpt_neox_layers_9_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[155] gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[156] gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[157] gpt_neox_layers_9_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[158] gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[159] gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[160] gpt_neox_layers_9_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[161] gpt_neox_layers_10_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[162] gpt_neox_layers_10_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[163] gpt_neox_layers_10_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[164] gpt_neox_layers_10_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[165] gpt_neox_layers_10_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[166] gpt_neox_layers_10_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[167] gpt_neox_layers_10_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[168] gpt_neox_layers_10_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[169] gpt_neox_layers_10_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[170] gpt_neox_layers_10_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[171] gpt_neox_layers_10_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[172] gpt_neox_layers_10_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[173] gpt_neox_layers_10_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[174] gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[175] gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[176] gpt_neox_layers_10_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[177] gpt_neox_layers_11_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[178] gpt_neox_layers_11_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[179] gpt_neox_layers_11_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[180] gpt_neox_layers_11_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[181] gpt_neox_layers_11_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[182] gpt_neox_layers_11_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[183] 
gpt_neox_layers_11_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[184] gpt_neox_layers_11_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[185] gpt_neox_layers_11_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[186] gpt_neox_layers_11_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[187] gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[188] gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[189] gpt_neox_layers_11_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[190] gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[191] gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[192] gpt_neox_layers_11_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[193] gpt_neox_layers_12_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[194] gpt_neox_layers_12_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[195] gpt_neox_layers_12_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[196] gpt_neox_layers_12_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[197] gpt_neox_layers_12_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[198] gpt_neox_layers_12_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[199] gpt_neox_layers_12_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[200] gpt_neox_layers_12_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[201] gpt_neox_layers_12_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[202] gpt_neox_layers_12_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[203] gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[204] gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[205] gpt_neox_layers_12_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[206] gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[207] gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[208] gpt_neox_layers_12_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[209] gpt_neox_layers_13_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[210] gpt_neox_layers_13_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[211] gpt_neox_layers_13_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[212] gpt_neox_layers_13_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[213] gpt_neox_layers_13_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[214] gpt_neox_layers_13_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[215] gpt_neox_layers_13_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[216] gpt_neox_layers_13_attention_dense_q_weight1: R.Tensor((2048, 256), 
dtype="uint32") = packed_params[217] gpt_neox_layers_13_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[218] gpt_neox_layers_13_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[219] gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[220] gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[221] gpt_neox_layers_13_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[222] gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[223] gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[224] gpt_neox_layers_13_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[225] gpt_neox_layers_14_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[226] gpt_neox_layers_14_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[227] gpt_neox_layers_14_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[228] gpt_neox_layers_14_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[229] gpt_neox_layers_14_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[230] gpt_neox_layers_14_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[231] gpt_neox_layers_14_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[232] gpt_neox_layers_14_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[233] gpt_neox_layers_14_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[234] gpt_neox_layers_14_attention_dense_bias1: R.Tensor((2048,), dtype="float16") = packed_params[235] gpt_neox_layers_14_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[236] gpt_neox_layers_14_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[237] gpt_neox_layers_14_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[238] gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[239] gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[240] gpt_neox_layers_14_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[241] gpt_neox_layers_15_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[242] gpt_neox_layers_15_input_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[243] gpt_neox_layers_15_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[244] gpt_neox_layers_15_post_attention_layernorm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[245] gpt_neox_layers_15_attention_query_key_value_q_weight1: R.Tensor((6144, 256), dtype="uint32") = packed_params[246] gpt_neox_layers_15_attention_query_key_value_q_scale1: R.Tensor((6144, 64), dtype="float16") = packed_params[247] gpt_neox_layers_15_attention_query_key_value_bias1: R.Tensor((6144,), dtype="float16") = packed_params[248] gpt_neox_layers_15_attention_dense_q_weight1: R.Tensor((2048, 256), dtype="uint32") = packed_params[249] gpt_neox_layers_15_attention_dense_q_scale1: R.Tensor((2048, 64), dtype="float16") = packed_params[250] gpt_neox_layers_15_attention_dense_bias1: 
R.Tensor((2048,), dtype="float16") = packed_params[251] gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight1: R.Tensor((8192, 256), dtype="uint32") = packed_params[252] gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale1: R.Tensor((8192, 64), dtype="float16") = packed_params[253] gpt_neox_layers_15_mlp_dense_h_to_4h_bias1: R.Tensor((8192,), dtype="float32") = packed_params[254] gpt_neox_layers_15_mlp_dense_4h_to_h_q_weight1: R.Tensor((2048, 1024), dtype="uint32") = packed_params[255] gpt_neox_layers_15_mlp_dense_4h_to_h_q_scale1: R.Tensor((2048, 256), dtype="float16") = packed_params[256] gpt_neox_layers_15_mlp_dense_4h_to_h_bias1: R.Tensor((2048,), dtype="float32") = packed_params[257] gpt_neox_final_layer_norm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[258] gpt_neox_final_layer_norm_bias1: R.Tensor((2048,), dtype="float16") = packed_params[259] embed_out_q_weight1: R.Tensor((vocab_size, 256), dtype="uint32") = packed_params[260] embed_out_q_scale1: R.Tensor((vocab_size, 64), dtype="float16") = packed_params[261] layer_norm: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(input_embed, gpt_neox_layers_0_input_layernorm_weight1, gpt_neox_layers_0_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv1 = R.call_tir(cls.dequantize1, (gpt_neox_layers_0_attention_query_key_value_q_weight1, gpt_neox_layers_0_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv1, axes=None) matmul: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm, permute_dims, out_dtype="void") add: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul, gpt_neox_layers_0_attention_query_key_value_bias1) reshape: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add, R.shape([1, seq_len, 24, 256])) reshape1: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape, R.shape([seq_len, 24, 256])) lv2 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape1), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape2: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv2, R.shape([1, seq_len, 8, 256])) reshape3: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape2, R.shape([1, seq_len, 2048])) lv3 = R.call_tir(cls.dequantize2, (gpt_neox_layers_0_attention_dense_q_weight1, gpt_neox_layers_0_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims1: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv3, axes=None) matmul1: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape3, permute_dims1, out_dtype="void") add1: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul1, gpt_neox_layers_0_attention_dense_bias1) layer_norm1: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(input_embed, gpt_neox_layers_0_post_attention_layernorm_weight1, gpt_neox_layers_0_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv4 = R.call_tir(cls.dequantize3, (gpt_neox_layers_0_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_0_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims2: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv4, axes=None) matmul2: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm1, permute_dims2, out_dtype="float32") 
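# Layer 0 above (and every layer after it) follows GPT-NeoX's parallel-residual form:
# both layernorms are applied to the same layer input (layer_norm and layer_norm1 both
# read input_embed), the attention and MLP branches run independently, and the adds
# that close the block sum mlp_out + attn_out + input. Attention itself is a single
# packed call, vm.builtin.attention_kv_cache_attention_with_fused_qkv, taking the
# cache object, the layer index as a prim_value, and a 1.0 scaling multiplier.
# The per-layer dataflow, in compact pseudo-Python (a sketch, not the emitted IR):
#
#   def neox_layer(x, layer_idx):
#       attn = dense(kv_cache_attention(qkv(ln_in(x)), layer_idx))
#       mlp = f16(down_f32(gelu_f32(up_f32(ln_post(x)))))   # fp32 inner MLP
#       return mlp + attn + x                               # parallel residual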
add2: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul2, gpt_neox_layers_0_mlp_dense_h_to_4h_bias1) gelu: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add2) astype: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu, dtype="float16") lv5 = R.call_tir(cls.dequantize4, (gpt_neox_layers_0_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_0_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims3: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv5, axes=None) matmul3: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype, permute_dims3, out_dtype="float32") add3: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul3, gpt_neox_layers_0_mlp_dense_4h_to_h_bias1) astype1: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add3, dtype="float16") add4: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype1, add1) add5: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add4, input_embed) layer_norm2: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add5, gpt_neox_layers_1_input_layernorm_weight1, gpt_neox_layers_1_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv6 = R.call_tir(cls.dequantize1, (gpt_neox_layers_1_attention_query_key_value_q_weight1, gpt_neox_layers_1_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims4: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv6, axes=None) matmul4: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm2, permute_dims4, out_dtype="void") add6: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul4, gpt_neox_layers_1_attention_query_key_value_bias1) reshape4: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add6, R.shape([1, seq_len, 24, 256])) reshape5: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape4, R.shape([seq_len, 24, 256])) lv7 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape5), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape6: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv7, R.shape([1, seq_len, 8, 256])) reshape7: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape6, R.shape([1, seq_len, 2048])) lv8 = R.call_tir(cls.dequantize2, (gpt_neox_layers_1_attention_dense_q_weight1, gpt_neox_layers_1_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims5: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv8, axes=None) matmul5: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape7, permute_dims5, out_dtype="void") add7: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul5, gpt_neox_layers_1_attention_dense_bias1) layer_norm3: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add5, gpt_neox_layers_1_post_attention_layernorm_weight1, gpt_neox_layers_1_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv9 = R.call_tir(cls.dequantize3, (gpt_neox_layers_1_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_1_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims6: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv9, axes=None) matmul6: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm3, permute_dims6, out_dtype="float32") add8: R.Tensor((1, seq_len, 8192), 
dtype="float32") = R.add(matmul6, gpt_neox_layers_1_mlp_dense_h_to_4h_bias1) gelu1: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add8) astype2: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu1, dtype="float16") lv10 = R.call_tir(cls.dequantize4, (gpt_neox_layers_1_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_1_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims7: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv10, axes=None) matmul7: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype2, permute_dims7, out_dtype="float32") add9: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul7, gpt_neox_layers_1_mlp_dense_4h_to_h_bias1) astype3: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add9, dtype="float16") add10: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype3, add7) add11: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add10, add5) layer_norm4: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add11, gpt_neox_layers_2_input_layernorm_weight1, gpt_neox_layers_2_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv11 = R.call_tir(cls.dequantize1, (gpt_neox_layers_2_attention_query_key_value_q_weight1, gpt_neox_layers_2_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims8: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv11, axes=None) matmul8: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm4, permute_dims8, out_dtype="void") add12: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul8, gpt_neox_layers_2_attention_query_key_value_bias1) reshape8: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add12, R.shape([1, seq_len, 24, 256])) reshape9: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape8, R.shape([seq_len, 24, 256])) lv12 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape9), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape10: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv12, R.shape([1, seq_len, 8, 256])) reshape11: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape10, R.shape([1, seq_len, 2048])) lv13 = R.call_tir(cls.dequantize2, (gpt_neox_layers_2_attention_dense_q_weight1, gpt_neox_layers_2_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims9: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv13, axes=None) matmul9: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape11, permute_dims9, out_dtype="void") add13: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul9, gpt_neox_layers_2_attention_dense_bias1) layer_norm5: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add11, gpt_neox_layers_2_post_attention_layernorm_weight1, gpt_neox_layers_2_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv14 = R.call_tir(cls.dequantize3, (gpt_neox_layers_2_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_2_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims10: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv14, axes=None) matmul10: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm5, permute_dims10, out_dtype="float32") add14: R.Tensor((1, seq_len, 8192), dtype="float32") 
= R.add(matmul10, gpt_neox_layers_2_mlp_dense_h_to_4h_bias1) gelu2: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add14) astype4: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu2, dtype="float16") lv15 = R.call_tir(cls.dequantize4, (gpt_neox_layers_2_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_2_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims11: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv15, axes=None) matmul11: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype4, permute_dims11, out_dtype="float32") add15: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul11, gpt_neox_layers_2_mlp_dense_4h_to_h_bias1) astype5: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add15, dtype="float16") add16: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype5, add13) add17: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add16, add11) layer_norm6: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add17, gpt_neox_layers_3_input_layernorm_weight1, gpt_neox_layers_3_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv16 = R.call_tir(cls.dequantize1, (gpt_neox_layers_3_attention_query_key_value_q_weight1, gpt_neox_layers_3_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims12: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv16, axes=None) matmul12: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm6, permute_dims12, out_dtype="void") add18: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul12, gpt_neox_layers_3_attention_query_key_value_bias1) reshape12: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add18, R.shape([1, seq_len, 24, 256])) reshape13: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape12, R.shape([seq_len, 24, 256])) lv17 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape13), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape14: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv17, R.shape([1, seq_len, 8, 256])) reshape15: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape14, R.shape([1, seq_len, 2048])) lv18 = R.call_tir(cls.dequantize2, (gpt_neox_layers_3_attention_dense_q_weight1, gpt_neox_layers_3_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims13: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv18, axes=None) matmul13: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape15, permute_dims13, out_dtype="void") add19: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul13, gpt_neox_layers_3_attention_dense_bias1) layer_norm7: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add17, gpt_neox_layers_3_post_attention_layernorm_weight1, gpt_neox_layers_3_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv19 = R.call_tir(cls.dequantize3, (gpt_neox_layers_3_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_3_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims14: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv19, axes=None) matmul14: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm7, permute_dims14, out_dtype="float32") add20: R.Tensor((1, seq_len, 8192), 
dtype="float32") = R.add(matmul14, gpt_neox_layers_3_mlp_dense_h_to_4h_bias1) gelu3: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add20) astype6: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu3, dtype="float16") lv20 = R.call_tir(cls.dequantize4, (gpt_neox_layers_3_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_3_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims15: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv20, axes=None) matmul15: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype6, permute_dims15, out_dtype="float32") add21: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul15, gpt_neox_layers_3_mlp_dense_4h_to_h_bias1) astype7: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add21, dtype="float16") add22: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype7, add19) add23: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add22, add17) layer_norm8: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add23, gpt_neox_layers_4_input_layernorm_weight1, gpt_neox_layers_4_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv21 = R.call_tir(cls.dequantize1, (gpt_neox_layers_4_attention_query_key_value_q_weight1, gpt_neox_layers_4_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims16: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv21, axes=None) matmul16: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm8, permute_dims16, out_dtype="void") add24: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul16, gpt_neox_layers_4_attention_query_key_value_bias1) reshape16: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add24, R.shape([1, seq_len, 24, 256])) reshape17: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape16, R.shape([seq_len, 24, 256])) lv22 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape17), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape18: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv22, R.shape([1, seq_len, 8, 256])) reshape19: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape18, R.shape([1, seq_len, 2048])) lv23 = R.call_tir(cls.dequantize2, (gpt_neox_layers_4_attention_dense_q_weight1, gpt_neox_layers_4_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims17: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv23, axes=None) matmul17: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape19, permute_dims17, out_dtype="void") add25: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul17, gpt_neox_layers_4_attention_dense_bias1) layer_norm9: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add23, gpt_neox_layers_4_post_attention_layernorm_weight1, gpt_neox_layers_4_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv24 = R.call_tir(cls.dequantize3, (gpt_neox_layers_4_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_4_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims18: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv24, axes=None) matmul18: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm9, permute_dims18, out_dtype="float32") add26: R.Tensor((1, seq_len, 
8192), dtype="float32") = R.add(matmul18, gpt_neox_layers_4_mlp_dense_h_to_4h_bias1) gelu4: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add26) astype8: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu4, dtype="float16") lv25 = R.call_tir(cls.dequantize4, (gpt_neox_layers_4_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_4_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims19: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv25, axes=None) matmul19: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype8, permute_dims19, out_dtype="float32") add27: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul19, gpt_neox_layers_4_mlp_dense_4h_to_h_bias1) astype9: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add27, dtype="float16") add28: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype9, add25) add29: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add28, add23) layer_norm10: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add29, gpt_neox_layers_5_input_layernorm_weight1, gpt_neox_layers_5_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv26 = R.call_tir(cls.dequantize1, (gpt_neox_layers_5_attention_query_key_value_q_weight1, gpt_neox_layers_5_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims20: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv26, axes=None) matmul20: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm10, permute_dims20, out_dtype="void") add30: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul20, gpt_neox_layers_5_attention_query_key_value_bias1) reshape20: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add30, R.shape([1, seq_len, 24, 256])) reshape21: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape20, R.shape([seq_len, 24, 256])) lv27 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape21), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape22: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv27, R.shape([1, seq_len, 8, 256])) reshape23: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape22, R.shape([1, seq_len, 2048])) lv28 = R.call_tir(cls.dequantize2, (gpt_neox_layers_5_attention_dense_q_weight1, gpt_neox_layers_5_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims21: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv28, axes=None) matmul21: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape23, permute_dims21, out_dtype="void") add31: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul21, gpt_neox_layers_5_attention_dense_bias1) layer_norm11: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add29, gpt_neox_layers_5_post_attention_layernorm_weight1, gpt_neox_layers_5_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv29 = R.call_tir(cls.dequantize3, (gpt_neox_layers_5_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_5_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims22: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv29, axes=None) matmul22: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm11, permute_dims22, out_dtype="float32") add32: R.Tensor((1, 
seq_len, 8192), dtype="float32") = R.add(matmul22, gpt_neox_layers_5_mlp_dense_h_to_4h_bias1) gelu5: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add32) astype10: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu5, dtype="float16") lv30 = R.call_tir(cls.dequantize4, (gpt_neox_layers_5_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_5_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims23: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv30, axes=None) matmul23: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype10, permute_dims23, out_dtype="float32") add33: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul23, gpt_neox_layers_5_mlp_dense_4h_to_h_bias1) astype11: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add33, dtype="float16") add34: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype11, add31) add35: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add34, add29) layer_norm12: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add35, gpt_neox_layers_6_input_layernorm_weight1, gpt_neox_layers_6_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv31 = R.call_tir(cls.dequantize1, (gpt_neox_layers_6_attention_query_key_value_q_weight1, gpt_neox_layers_6_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims24: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv31, axes=None) matmul24: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm12, permute_dims24, out_dtype="void") add36: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul24, gpt_neox_layers_6_attention_query_key_value_bias1) reshape24: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add36, R.shape([1, seq_len, 24, 256])) reshape25: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape24, R.shape([seq_len, 24, 256])) lv32 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape25), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape26: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv32, R.shape([1, seq_len, 8, 256])) reshape27: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape26, R.shape([1, seq_len, 2048])) lv33 = R.call_tir(cls.dequantize2, (gpt_neox_layers_6_attention_dense_q_weight1, gpt_neox_layers_6_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims25: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv33, axes=None) matmul25: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape27, permute_dims25, out_dtype="void") add37: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul25, gpt_neox_layers_6_attention_dense_bias1) layer_norm13: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add35, gpt_neox_layers_6_post_attention_layernorm_weight1, gpt_neox_layers_6_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv34 = R.call_tir(cls.dequantize3, (gpt_neox_layers_6_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_6_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims26: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv34, axes=None) matmul26: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm13, permute_dims26, out_dtype="float32") add38: 
R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul26, gpt_neox_layers_6_mlp_dense_h_to_4h_bias1) gelu6: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add38) astype12: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu6, dtype="float16") lv35 = R.call_tir(cls.dequantize4, (gpt_neox_layers_6_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_6_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims27: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv35, axes=None) matmul27: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype12, permute_dims27, out_dtype="float32") add39: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul27, gpt_neox_layers_6_mlp_dense_4h_to_h_bias1) astype13: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add39, dtype="float16") add40: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype13, add37) add41: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add40, add35) layer_norm14: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add41, gpt_neox_layers_7_input_layernorm_weight1, gpt_neox_layers_7_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv36 = R.call_tir(cls.dequantize1, (gpt_neox_layers_7_attention_query_key_value_q_weight1, gpt_neox_layers_7_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims28: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv36, axes=None) matmul28: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm14, permute_dims28, out_dtype="void") add42: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul28, gpt_neox_layers_7_attention_query_key_value_bias1) reshape28: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add42, R.shape([1, seq_len, 24, 256])) reshape29: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape28, R.shape([seq_len, 24, 256])) lv37 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape29), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape30: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv37, R.shape([1, seq_len, 8, 256])) reshape31: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape30, R.shape([1, seq_len, 2048])) lv38 = R.call_tir(cls.dequantize2, (gpt_neox_layers_7_attention_dense_q_weight1, gpt_neox_layers_7_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims29: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv38, axes=None) matmul29: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape31, permute_dims29, out_dtype="void") add43: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul29, gpt_neox_layers_7_attention_dense_bias1) layer_norm15: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add41, gpt_neox_layers_7_post_attention_layernorm_weight1, gpt_neox_layers_7_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv39 = R.call_tir(cls.dequantize3, (gpt_neox_layers_7_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_7_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims30: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv39, axes=None) matmul30: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm15, permute_dims30, 
out_dtype="float32") add44: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul30, gpt_neox_layers_7_mlp_dense_h_to_4h_bias1) gelu7: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add44) astype14: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu7, dtype="float16") lv40 = R.call_tir(cls.dequantize4, (gpt_neox_layers_7_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_7_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims31: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv40, axes=None) matmul31: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype14, permute_dims31, out_dtype="float32") add45: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul31, gpt_neox_layers_7_mlp_dense_4h_to_h_bias1) astype15: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add45, dtype="float16") add46: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype15, add43) add47: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add46, add41) layer_norm16: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add47, gpt_neox_layers_8_input_layernorm_weight1, gpt_neox_layers_8_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv41 = R.call_tir(cls.dequantize1, (gpt_neox_layers_8_attention_query_key_value_q_weight1, gpt_neox_layers_8_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims32: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv41, axes=None) matmul32: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm16, permute_dims32, out_dtype="void") add48: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul32, gpt_neox_layers_8_attention_query_key_value_bias1) reshape32: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add48, R.shape([1, seq_len, 24, 256])) reshape33: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape32, R.shape([seq_len, 24, 256])) lv42 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape33), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape34: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv42, R.shape([1, seq_len, 8, 256])) reshape35: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape34, R.shape([1, seq_len, 2048])) lv43 = R.call_tir(cls.dequantize2, (gpt_neox_layers_8_attention_dense_q_weight1, gpt_neox_layers_8_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims33: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv43, axes=None) matmul33: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape35, permute_dims33, out_dtype="void") add49: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul33, gpt_neox_layers_8_attention_dense_bias1) layer_norm17: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add47, gpt_neox_layers_8_post_attention_layernorm_weight1, gpt_neox_layers_8_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv44 = R.call_tir(cls.dequantize3, (gpt_neox_layers_8_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_8_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims34: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv44, axes=None) matmul34: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm17, 
permute_dims34, out_dtype="float32") add50: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul34, gpt_neox_layers_8_mlp_dense_h_to_4h_bias1) gelu8: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add50) astype16: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu8, dtype="float16") lv45 = R.call_tir(cls.dequantize4, (gpt_neox_layers_8_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_8_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims35: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv45, axes=None) matmul35: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype16, permute_dims35, out_dtype="float32") add51: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul35, gpt_neox_layers_8_mlp_dense_4h_to_h_bias1) astype17: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add51, dtype="float16") add52: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype17, add49) add53: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add52, add47) layer_norm18: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add53, gpt_neox_layers_9_input_layernorm_weight1, gpt_neox_layers_9_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv46 = R.call_tir(cls.dequantize1, (gpt_neox_layers_9_attention_query_key_value_q_weight1, gpt_neox_layers_9_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims36: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv46, axes=None) matmul36: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm18, permute_dims36, out_dtype="void") add54: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul36, gpt_neox_layers_9_attention_query_key_value_bias1) reshape36: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add54, R.shape([1, seq_len, 24, 256])) reshape37: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape36, R.shape([seq_len, 24, 256])) lv47 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape37), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape38: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv47, R.shape([1, seq_len, 8, 256])) reshape39: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape38, R.shape([1, seq_len, 2048])) lv48 = R.call_tir(cls.dequantize2, (gpt_neox_layers_9_attention_dense_q_weight1, gpt_neox_layers_9_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims37: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv48, axes=None) matmul37: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape39, permute_dims37, out_dtype="void") add55: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul37, gpt_neox_layers_9_attention_dense_bias1) layer_norm19: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add53, gpt_neox_layers_9_post_attention_layernorm_weight1, gpt_neox_layers_9_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv49 = R.call_tir(cls.dequantize3, (gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims38: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv49, axes=None) matmul38: R.Tensor((1, seq_len, 8192), dtype="float32") = 
            layer_norm18: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add53, gpt_neox_layers_9_input_layernorm_weight1, gpt_neox_layers_9_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
            lv46 = R.call_tir(cls.dequantize1, (gpt_neox_layers_9_attention_query_key_value_q_weight1, gpt_neox_layers_9_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
            permute_dims36: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv46, axes=None)
            matmul36: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm18, permute_dims36, out_dtype="void")
            add54: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul36, gpt_neox_layers_9_attention_query_key_value_bias1)
            reshape36: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add54, R.shape([1, seq_len, 24, 256]))
            reshape37: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape36, R.shape([seq_len, 24, 256]))
            lv47 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape37), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
            reshape38: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv47, R.shape([1, seq_len, 8, 256]))
            reshape39: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape38, R.shape([1, seq_len, 2048]))
            lv48 = R.call_tir(cls.dequantize2, (gpt_neox_layers_9_attention_dense_q_weight1, gpt_neox_layers_9_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims37: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv48, axes=None)
            matmul37: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape39, permute_dims37, out_dtype="void")
            add55: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul37, gpt_neox_layers_9_attention_dense_bias1)
            layer_norm19: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add53, gpt_neox_layers_9_post_attention_layernorm_weight1, gpt_neox_layers_9_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
            lv49 = R.call_tir(cls.dequantize3, (gpt_neox_layers_9_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_9_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16"))
            permute_dims38: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv49, axes=None)
            matmul38: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm19, permute_dims38, out_dtype="float32")
            add56: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul38, gpt_neox_layers_9_mlp_dense_h_to_4h_bias1)
            gelu9: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add56)
            astype18: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu9, dtype="float16")
            lv50 = R.call_tir(cls.dequantize4, (gpt_neox_layers_9_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_9_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
            permute_dims39: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv50, axes=None)
            matmul39: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype18, permute_dims39, out_dtype="float32")
            add57: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul39, gpt_neox_layers_9_mlp_dense_4h_to_h_bias1)
            astype19: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add57, dtype="float16")
            add58: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype19, add55)
            add59: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add58, add53)
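            # --- Decoder layer 10 (KV-cache layer index 10) ---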
8192), dtype="float32") = R.matmul(layer_norm21, permute_dims42, out_dtype="float32") add62: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul42, gpt_neox_layers_10_mlp_dense_h_to_4h_bias1) gelu10: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add62) astype20: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu10, dtype="float16") lv55 = R.call_tir(cls.dequantize4, (gpt_neox_layers_10_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_10_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims43: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv55, axes=None) matmul43: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype20, permute_dims43, out_dtype="float32") add63: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul43, gpt_neox_layers_10_mlp_dense_4h_to_h_bias1) astype21: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add63, dtype="float16") add64: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype21, add61) add65: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add64, add59) layer_norm22: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add65, gpt_neox_layers_11_input_layernorm_weight1, gpt_neox_layers_11_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv56 = R.call_tir(cls.dequantize1, (gpt_neox_layers_11_attention_query_key_value_q_weight1, gpt_neox_layers_11_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims44: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv56, axes=None) matmul44: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm22, permute_dims44, out_dtype="void") add66: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul44, gpt_neox_layers_11_attention_query_key_value_bias1) reshape44: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add66, R.shape([1, seq_len, 24, 256])) reshape45: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape44, R.shape([seq_len, 24, 256])) lv57 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape45), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape46: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv57, R.shape([1, seq_len, 8, 256])) reshape47: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape46, R.shape([1, seq_len, 2048])) lv58 = R.call_tir(cls.dequantize2, (gpt_neox_layers_11_attention_dense_q_weight1, gpt_neox_layers_11_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims45: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv58, axes=None) matmul45: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape47, permute_dims45, out_dtype="void") add67: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul45, gpt_neox_layers_11_attention_dense_bias1) layer_norm23: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add65, gpt_neox_layers_11_post_attention_layernorm_weight1, gpt_neox_layers_11_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv59 = R.call_tir(cls.dequantize3, (gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16")) permute_dims46: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv59, axes=None) 
            layer_norm22: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add65, gpt_neox_layers_11_input_layernorm_weight1, gpt_neox_layers_11_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
            lv56 = R.call_tir(cls.dequantize1, (gpt_neox_layers_11_attention_query_key_value_q_weight1, gpt_neox_layers_11_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
            permute_dims44: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv56, axes=None)
            matmul44: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm22, permute_dims44, out_dtype="void")
            add66: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul44, gpt_neox_layers_11_attention_query_key_value_bias1)
            reshape44: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add66, R.shape([1, seq_len, 24, 256]))
            reshape45: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape44, R.shape([seq_len, 24, 256]))
            lv57 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape45), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
            reshape46: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv57, R.shape([1, seq_len, 8, 256]))
            reshape47: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape46, R.shape([1, seq_len, 2048]))
            lv58 = R.call_tir(cls.dequantize2, (gpt_neox_layers_11_attention_dense_q_weight1, gpt_neox_layers_11_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims45: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv58, axes=None)
            matmul45: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape47, permute_dims45, out_dtype="void")
            add67: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul45, gpt_neox_layers_11_attention_dense_bias1)
            layer_norm23: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add65, gpt_neox_layers_11_post_attention_layernorm_weight1, gpt_neox_layers_11_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
            lv59 = R.call_tir(cls.dequantize3, (gpt_neox_layers_11_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_11_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16"))
            permute_dims46: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv59, axes=None)
            matmul46: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm23, permute_dims46, out_dtype="float32")
            add68: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul46, gpt_neox_layers_11_mlp_dense_h_to_4h_bias1)
            gelu11: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add68)
            astype22: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu11, dtype="float16")
            lv60 = R.call_tir(cls.dequantize4, (gpt_neox_layers_11_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_11_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
            permute_dims47: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv60, axes=None)
            matmul47: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype22, permute_dims47, out_dtype="float32")
            add69: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul47, gpt_neox_layers_11_mlp_dense_4h_to_h_bias1)
            astype23: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add69, dtype="float16")
            add70: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype23, add67)
            add71: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add70, add65)
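            # --- Decoder layer 12 (KV-cache layer index 12) ---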
            layer_norm24: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add71, gpt_neox_layers_12_input_layernorm_weight1, gpt_neox_layers_12_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
            lv61 = R.call_tir(cls.dequantize1, (gpt_neox_layers_12_attention_query_key_value_q_weight1, gpt_neox_layers_12_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
            permute_dims48: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv61, axes=None)
            matmul48: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm24, permute_dims48, out_dtype="void")
            add72: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul48, gpt_neox_layers_12_attention_query_key_value_bias1)
            reshape48: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add72, R.shape([1, seq_len, 24, 256]))
            reshape49: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape48, R.shape([seq_len, 24, 256]))
            lv62 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape49), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
            reshape50: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv62, R.shape([1, seq_len, 8, 256]))
            reshape51: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape50, R.shape([1, seq_len, 2048]))
            lv63 = R.call_tir(cls.dequantize2, (gpt_neox_layers_12_attention_dense_q_weight1, gpt_neox_layers_12_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims49: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv63, axes=None)
            matmul49: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape51, permute_dims49, out_dtype="void")
            add73: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul49, gpt_neox_layers_12_attention_dense_bias1)
            layer_norm25: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add71, gpt_neox_layers_12_post_attention_layernorm_weight1, gpt_neox_layers_12_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
            lv64 = R.call_tir(cls.dequantize3, (gpt_neox_layers_12_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_12_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16"))
            permute_dims50: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv64, axes=None)
            matmul50: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm25, permute_dims50, out_dtype="float32")
            add74: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul50, gpt_neox_layers_12_mlp_dense_h_to_4h_bias1)
            gelu12: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add74)
            astype24: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu12, dtype="float16")
            lv65 = R.call_tir(cls.dequantize4, (gpt_neox_layers_12_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_12_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
            permute_dims51: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv65, axes=None)
            matmul51: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype24, permute_dims51, out_dtype="float32")
            add75: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul51, gpt_neox_layers_12_mlp_dense_4h_to_h_bias1)
            astype25: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add75, dtype="float16")
            add76: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype25, add73)
            add77: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add76, add71)
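            # --- Decoder layer 13 (KV-cache layer index 13) ---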
            layer_norm26: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add77, gpt_neox_layers_13_input_layernorm_weight1, gpt_neox_layers_13_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
            lv66 = R.call_tir(cls.dequantize1, (gpt_neox_layers_13_attention_query_key_value_q_weight1, gpt_neox_layers_13_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16"))
            permute_dims52: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv66, axes=None)
            matmul52: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm26, permute_dims52, out_dtype="void")
            add78: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul52, gpt_neox_layers_13_attention_query_key_value_bias1)
            reshape52: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add78, R.shape([1, seq_len, 24, 256]))
            reshape53: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape52, R.shape([seq_len, 24, 256]))
            lv67 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape53), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16"))
            reshape54: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv67, R.shape([1, seq_len, 8, 256]))
            reshape55: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape54, R.shape([1, seq_len, 2048]))
            lv68 = R.call_tir(cls.dequantize2, (gpt_neox_layers_13_attention_dense_q_weight1, gpt_neox_layers_13_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims53: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv68, axes=None)
            matmul53: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape55, permute_dims53, out_dtype="void")
            add79: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul53, gpt_neox_layers_13_attention_dense_bias1)
            layer_norm27: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add77, gpt_neox_layers_13_post_attention_layernorm_weight1, gpt_neox_layers_13_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
            lv69 = R.call_tir(cls.dequantize3, (gpt_neox_layers_13_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_13_mlp_dense_h_to_4h_q_scale1), out_sinfo=R.Tensor((8192, 2048), dtype="float16"))
            permute_dims54: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv69, axes=None)
            matmul54: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm27, permute_dims54, out_dtype="float32")
            add80: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul54, gpt_neox_layers_13_mlp_dense_h_to_4h_bias1)
            gelu13: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add80)
            astype26: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu13, dtype="float16")
            lv70 = R.call_tir(cls.dequantize4, (gpt_neox_layers_13_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_13_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16"))
            permute_dims55: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv70, axes=None)
            matmul55: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype26, permute_dims55, out_dtype="float32")
            add81: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul55, gpt_neox_layers_13_mlp_dense_4h_to_h_bias1)
            astype27: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add81, dtype="float16")
            add82: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype27, add79)
            add83: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add82, add77)
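            # --- Decoder layers 14 and 15 (KV-cache layer indices 14 and 15), the last two
            # repetitions of the same pattern before the final layer norm. ---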
dtype="float16")) permute_dims58: R.Tensor((2048, 8192), dtype="float16") = R.permute_dims(lv74, axes=None) matmul58: R.Tensor((1, seq_len, 8192), dtype="float32") = R.matmul(layer_norm29, permute_dims58, out_dtype="float32") add86: R.Tensor((1, seq_len, 8192), dtype="float32") = R.add(matmul58, gpt_neox_layers_14_mlp_dense_h_to_4h_bias1) gelu14: R.Tensor((1, seq_len, 8192), dtype="float32") = R.nn.gelu(add86) astype28: R.Tensor((1, seq_len, 8192), dtype="float16") = R.astype(gelu14, dtype="float16") lv75 = R.call_tir(cls.dequantize4, (gpt_neox_layers_14_mlp_dense_4h_to_h_q_weight1, gpt_neox_layers_14_mlp_dense_4h_to_h_q_scale1), out_sinfo=R.Tensor((2048, 8192), dtype="float16")) permute_dims59: R.Tensor((8192, 2048), dtype="float16") = R.permute_dims(lv75, axes=None) matmul59: R.Tensor((1, seq_len, 2048), dtype="float32") = R.matmul(astype28, permute_dims59, out_dtype="float32") add87: R.Tensor((1, seq_len, 2048), dtype="float32") = R.add(matmul59, gpt_neox_layers_14_mlp_dense_4h_to_h_bias1) astype29: R.Tensor((1, seq_len, 2048), dtype="float16") = R.astype(add87, dtype="float16") add88: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(astype29, add85) add89: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(add88, add83) layer_norm30: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add89, gpt_neox_layers_15_input_layernorm_weight1, gpt_neox_layers_15_input_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv76 = R.call_tir(cls.dequantize1, (gpt_neox_layers_15_attention_query_key_value_q_weight1, gpt_neox_layers_15_attention_query_key_value_q_scale1), out_sinfo=R.Tensor((6144, 2048), dtype="float16")) permute_dims60: R.Tensor((2048, 6144), dtype="float16") = R.permute_dims(lv76, axes=None) matmul60: R.Tensor((1, seq_len, 6144), dtype="float16") = R.matmul(layer_norm30, permute_dims60, out_dtype="void") add90: R.Tensor((1, seq_len, 6144), dtype="float16") = R.add(matmul60, gpt_neox_layers_15_attention_query_key_value_bias1) reshape60: R.Tensor((1, seq_len, 24, 256), dtype="float16") = R.reshape(add90, R.shape([1, seq_len, 24, 256])) reshape61: R.Tensor((seq_len, 24, 256), dtype="float16") = R.reshape(reshape60, R.shape([seq_len, 24, 256])) lv77 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape61), out_sinfo=R.Tensor((seq_len, 8, 256), dtype="float16")) reshape62: R.Tensor((1, seq_len, 8, 256), dtype="float16") = R.reshape(lv77, R.shape([1, seq_len, 8, 256])) reshape63: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape62, R.shape([1, seq_len, 2048])) lv78 = R.call_tir(cls.dequantize2, (gpt_neox_layers_15_attention_dense_q_weight1, gpt_neox_layers_15_attention_dense_q_scale1), out_sinfo=R.Tensor((2048, 2048), dtype="float16")) permute_dims61: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv78, axes=None) matmul61: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape63, permute_dims61, out_dtype="void") add91: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul61, gpt_neox_layers_15_attention_dense_bias1) layer_norm31: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add89, gpt_neox_layers_15_post_attention_layernorm_weight1, gpt_neox_layers_15_post_attention_layernorm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True) lv79 = R.call_tir(cls.dequantize3, (gpt_neox_layers_15_mlp_dense_h_to_4h_q_weight1, gpt_neox_layers_15_mlp_dense_h_to_4h_q_scale1), 
            layer_norm32: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.layer_norm(add95, gpt_neox_final_layer_norm_weight1, gpt_neox_final_layer_norm_bias1, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
            lv81 = R.call_tir(cls.index, (layer_norm32,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv82 = R.call_tir(cls.dequantize, (embed_out_q_weight1, embed_out_q_scale1), out_sinfo=R.Tensor((vocab_size, 2048), dtype="float16"))
            permute_dims64: R.Tensor((2048, vocab_size), dtype="float16") = R.permute_dims(lv82, axes=None)
            matmul64: R.Tensor((1, 1, vocab_size), dtype="float16") = R.matmul(lv81, permute_dims64, out_dtype="void")
            astype32: R.Tensor((1, 1, vocab_size), dtype="float32") = R.astype(matmul64, dtype="float32")
            gv1: R.Tuple(R.Tensor((1, 1, vocab_size), dtype="float32"), R.Object) = astype32, paged_kv_cache
            R.output(gv1)
        return gv1
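    # Two-pass softmax over the vocabulary: cls.chunk_lse appears to compute per-chunk
    # maxima and log-sum-exp partial sums over 4096-wide slices of the vocab axis (hence
    # the (vocab_size + 4096 - 1) // 4096 output shape), and cls.softmax_with_chunked_sum
    # combines those partial results into normalized probabilities, so no single kernel
    # has to reduce over the full vocab_size axis at once. Both kernels take temperature,
    # so the temperature scaling happens inside the TIR functions.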
    @R.function
    def softmax_with_temperature(logits: R.Tensor(("batch_size", 1, "vocab_size"), dtype="float32"), temperature: R.Tensor(("batch_size",), dtype="float32")) -> R.Tensor(("batch_size", 1, "vocab_size"), dtype="float32"):
        batch_size = T.int64(is_size_var=True)
        vocab_size = T.int64(is_size_var=True)
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 128, "seq_len": 2048, "total_seq_len": 2048}})
        cls = Module
        with R.dataflow():
            lv: R.Tensor((batch_size, vocab_size), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", logits, R.shape([batch_size, vocab_size]), sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float32"),))
            lv1 = R.call_tir(cls.chunk_lse, (lv, temperature), out_sinfo=[R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32"), R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32")])
            lv2: R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32") = lv1[0]
            lv3: R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32") = lv1[1]
            lv4 = R.call_tir(cls.softmax_with_chunked_sum, (lv, temperature, lv2, lv3), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32"))
            gv: R.Tensor((batch_size, 1, vocab_size), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", lv4, R.shape([batch_size, 1, vocab_size]), sinfo_args=(R.Tensor((batch_size, 1, vocab_size), dtype="float32"),))
            R.output(gv)
        return gv